From 3e5e08b0b4fc0437ed2042e67c867bdb2003c248 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 3 Mar 2022 19:29:39 +0800
Subject: [PATCH] perf(mge/functional): speed up RemoveAxis

GitOrigin-RevId: 9c5d27fe1d19bb80848ee244f7c75f309b58df6a
---
 .../megengine/core/tensor/array_method.py |  25 +-----
 .../python/megengine/functional/tensor.py |   5 +-
 imperative/python/src/tensor.cpp          |   2 +
 imperative/python/src/tensor_utils.cpp    |  76 +++++++++++++++++--
 imperative/python/src/tensor_utils.h      |   2 +
 5 files changed, 80 insertions(+), 30 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index 9072f8e68..51d6becdd 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -17,6 +17,7 @@ from .. import _config
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
 from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
+from .._imperative_rt.core2 import squeeze_cpp
 from ..ops import builtin
 from . import amp
 from .indexing import getitem, setitem
@@ -448,26 +449,6 @@ def _logical_binary_elwise(mode, rev=False):
     return f
 
 
-def _remove_axis(inp: Tensor, axis) -> Tensor:
-    def get_axes():
-        if axis is None:
-            shp = inp.shape
-            return [i for i, s in enumerate(shp) if s == 1]
-        try:
-            return [int(axis)]
-        except (TypeError, ValueError):
-            pass
-        return list(map(int, axis))
-
-    axis = get_axes()
-    axis = _normalize_axis(inp.ndim, axis)
-    axis = [a - i for i, a in enumerate(axis)]
-
-    op = builtin.RemoveAxis(axis=axis)
-    (result,) = apply(op, inp)
-    return result
-
-
 def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
@@ -480,7 +461,7 @@ def _reduce(mode):
             op = builtin.Reduce(mode=mode, axis=ai)
             (data,) = apply(op, data)
             if not keepdims:
-                data = _remove_axis(data, ai)
+                data = squeeze_cpp(data, ai)
             result = data
         else:
             # builtin.Reduce already accepts negative axis
@@ -488,7 +469,7 @@
             (result,) = apply(op, data)
             if not keepdims:
-                result = _remove_axis(result, axis)
+                result = squeeze_cpp(result, axis)
 
         return result
 
     return f
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 689e0f0a0..20ad39db4 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -18,12 +18,13 @@ from ..core._imperative_rt.core2 import (
     dtype_promotion,
     expand_dims_cpp,
     split_cpp,
+    squeeze_cpp,
 )
 from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
-from ..core.tensor.array_method import _broadcast, _remove_axis
+from ..core.tensor.array_method import _broadcast
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -996,7 +997,7 @@ def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Te
 
         (1, 1, 2)
     """
-    return _remove_axis(inp, axis)
+    return squeeze_cpp(inp, axis)
 
 
 def linspace(
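The C++ fast path must keep the axis semantics of the deleted _remove_axis
above: axis=None drops every dimension of extent 1, a bare int drops one
axis, and a sequence drops several. A simplified pure-Python sketch of that
contract for reference; the helper name _axes_to_remove is hypothetical, and
the isinstance check is simplified relative to the original try/int coercion:

    from typing import Optional, Sequence, Union

    def _axes_to_remove(
        shape: Sequence[int], axis: Optional[Union[int, Sequence[int]]]
    ) -> list:
        # axis=None: every dim of extent 1 is squeezed away.
        if axis is None:
            return [i for i, s in enumerate(shape) if s == 1]
        # A bare int names a single axis; a sequence names several.
        if isinstance(axis, int):
            return [axis]
        return [int(a) for a in axis]

    assert _axes_to_remove((1, 2, 1, 4), None) == [0, 2]
    assert _axes_to_remove((1, 2, 1, 4), 0) == [0]
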
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 6efcb0f7e..f271a6896 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -635,6 +635,7 @@ WRAP_FUNC_PY35(getitem_cpp);
 WRAP_FUNC_PY35(setitem_cpp);
 WRAP_FUNC_PY35(split_cpp);
 WRAP_FUNC_PY35(expand_dims_cpp);
+WRAP_FUNC_PY35(squeeze_cpp);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -769,6 +770,7 @@ void init_tensor(py::module m) {
         MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
         MGE_PY_INTERFACE(split_cpp, split_cpp),
         MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
+        MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
         {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index 97798fc09..b46549c8d 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -683,17 +683,21 @@ py::object _split_cpp(
     return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
 }
 
-py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
+std::vector<int32_t> list2vector(py::handle li) {
     std::vector<int32_t> axis;
-    if (is_py_sequence(axis_hdl.ptr())) {
-        py::list tmp_list =
-                py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr()));
+    if (is_py_sequence(li.ptr())) {
+        py::list tmp_list = py::reinterpret_steal<py::list>(PySequence_List(li.ptr()));
         for (size_t i = 0; i < tmp_list.size(); ++i) {
             axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
         }
     } else {
-        axis.push_back(getattr(axis_hdl, "__int__")().cast<int32_t>());
+        axis.push_back(getattr(li, "__int__")().cast<int32_t>());
     }
+    return axis;
+}
+
+py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
+    std::vector<int32_t> axis = list2vector(axis_hdl);
     bool unknown_ndim = true;
     size_t ndim = axis.size();
     if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
@@ -718,7 +722,7 @@
                         "Does not support negative index when tensor's ndim is "
                         "unknown");
             }
-            axis[i] += ndim;
+            axis[i] += static_cast<int32_t>(ndim);
         }
     }
     if (!axis.size()) {
@@ -736,6 +740,59 @@
     return ret[0];
 }
 
+py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
+    std::vector<int32_t> axis;
+    size_t ndim = 0;  // stays 0 when the shape cannot be inferred
+    if (axis_hdl.ptr() != Py_None) {
+        axis = list2vector(axis_hdl);
+    }
+    if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
+        auto&& shape = p->m_tensor->shape();
+        if (shape) {
+            ndim = shape->ndim;
+            if (axis_hdl.ptr() == Py_None) {
+                for (size_t i = 0; i < shape->ndim; ++i) {
+                    if (shape->shape[i] == 1) {
+                        axis.push_back(i);
+                    }
+                }
+            }
+        }
+    } else {
+        auto&& var = inp_hdl.cast<PySymbolVar*>();
+        auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
+        auto&& shape = mgr.infer_shape_fallible(var->m_node);
+        if (shape) {
+            ndim = shape->ndim;
+            if (axis_hdl.ptr() == Py_None) {
+                for (size_t i = 0; i < shape->ndim; ++i) {
+                    if (shape->shape[i] == 1) {
+                        axis.push_back(i);
+                    }
+                }
+            }
+        }
+    }
+    for (size_t i = 0; i < axis.size(); ++i) {
+        if (axis[i] < 0) {
+            axis[i] += static_cast<int32_t>(ndim);
+        }
+    }
+    std::sort(axis.begin(), axis.end());
+    for (size_t i = 0; i < axis.size(); ++i) {
+        axis[i] -= static_cast<int32_t>(i);  // compensate for axes already removed
+    }
+    std::shared_ptr<OpDef> op = RemoveAxis::make(axis);
+    std::vector<PyObject*> p;
+    p.resize(2);
+    py::object Op = py::cast(op);
+    p[0] = Op.ptr();
+    p[1] = inp_hdl.ptr();
+    py::tuple ret =
+            py::reinterpret_steal<py::tuple>(py_apply(NULL, p.data(), p.size()));
+    return ret[0];
+}
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -778,4 +835,11 @@ PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
+PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _squeeze_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 }  // namespace mgb::imperative::python
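Both the deleted Python _remove_axis (its `axis = [a - i for i, a in
enumerate(axis)]`) and the new _squeeze_cpp normalize the axis list the same
way: RemoveAxis drops the requested axes one at a time, and each removal
shifts every later axis left by one, so the indices are sorted and then
pre-compensated. A standalone Python sketch of that step; the function name
is illustrative only:

    def _normalize_remove_axes(axes, ndim):
        # Resolve negative indices against the tensor's rank.
        axes = sorted(a + ndim if a < 0 else a for a in axes)
        # The k-th removal (in ascending order) has already shifted the
        # remaining axes left k times, so subtract the position.
        return [a - i for i, a in enumerate(axes)]

    # Squeezing axes 0 and 2 of a rank-4 tensor: once axis 0 is gone,
    # the original axis 2 sits at index 1.
    assert _normalize_remove_axes([0, 2], ndim=4) == [0, 1]
    assert _normalize_remove_axes([-4, -2], ndim=4) == [0, 1]
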
diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h
index bf6bc0468..55ea2c7d7 100644
--- a/imperative/python/src/tensor_utils.h
+++ b/imperative/python/src/tensor_utils.h
@@ -12,4 +12,6 @@ PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 
 PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 
+PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
 }  // namespace mgb::imperative::python
\ No newline at end of file
--
GitLab
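The user-visible behaviour of megengine.functional.squeeze is unchanged by
this patch; only the dispatch moves from Python into the C++ binding. A rough
way to sanity-check both the semantics and the intended speedup, assuming
eager mode where Tensor.shape compares equal to a plain tuple; the iteration
count and any timings are illustrative, not taken from the commit:

    import timeit

    import numpy as np

    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.zeros((1, 2, 1, 4), dtype=np.float32))

    # axis=None removes every dim of extent 1; explicit axes remove only those.
    assert F.squeeze(x).shape == (2, 4)
    assert F.squeeze(x, 0).shape == (2, 1, 4)
    assert F.squeeze(x, (0, 2)).shape == (2, 4)

    # Per-call overhead of the binding round trip, before vs. after the patch.
    print(timeit.timeit(lambda: F.squeeze(x, 0), number=10000))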