提交 0f736a0a 编写于 作者: M Megvii Engine Team

perf(mge/functional): speed up Dimshuffle

GitOrigin-RevId: 8160c9522bc59d06aab1321b87408ecb410b4a81
上级 3e5e08b0
......@@ -17,7 +17,7 @@ from .. import _config
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import squeeze_cpp
from .._imperative_rt.core2 import squeeze_cpp, transpose_cpp
from ..ops import builtin
from . import amp
from .indexing import getitem, setitem
......@@ -331,12 +331,6 @@ def _matmul(
return result
def _transpose(data, axes):
op = builtin.Dimshuffle(axes)
(result,) = apply(op, data)
return result
def _broadcast(inp, shape):
auto_infer = False
if isinstance(shape, (list, tuple)):
......@@ -681,15 +675,7 @@ class ArrayMethodMixin(abc.ABC):
def transpose(self, *args):
r"""See :func:`~.transpose`."""
if self.ndim == 0:
assert (
len(args) == 0
), "transpose for scalar does not accept additional args"
ret = self.to(self.device)
return ret
if not args:
args = range(self.ndim)[::-1]
return _transpose(self, _expand_args(args))
return transpose_cpp(self, args)
def flatten(self):
r"""See :func:`~.flatten`."""
......
......@@ -865,7 +865,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
[[1 0]
[1 0]]
"""
return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern))
return inp.transpose(pattern)
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
......
......@@ -636,6 +636,7 @@ WRAP_FUNC_PY35(setitem_cpp);
WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
......@@ -771,6 +772,7 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(split_cpp, split_cpp),
MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {
......
......@@ -793,6 +793,57 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
return ret[0];
}
// Returns the number of dimensions of `tensor`.
//
// When the handle can be cast to a TensorWrapper, the rank is read straight
// from the wrapped tensor's shape; otherwise the Python-level `ndim`
// attribute is looked up and converted to size_t.
size_t fast_ndim(py::handle tensor) {
    auto wrapper = TensorWrapper::try_cast(tensor.ptr());
    if (!wrapper) {
        // Not a TensorWrapper: fall back to the generic attribute lookup.
        py::object ndim_attr = getattr(tensor, "ndim");
        return ndim_attr.cast<size_t>();
    }
    return wrapper->m_tensor->shape()->ndim;
}
// Applies a Dimshuffle op to `inp_hdl` according to the axes given in `args`.
//
// Parameters:
//   inp_hdl - the input tensor (anything accepted by fast_ndim / py_apply).
//   args    - a Python tuple holding the transpose arguments. It may be
//             * empty: the axes are reversed;
//             * a single sequence / tensor / symbolvar: its elements are the
//               pattern;
//             * several values: the tuple itself is the pattern.
//             Each pattern element is either a Python int or the string "x",
//             which is encoded as -1 in the Dimshuffle pattern.
//
// Returns the first output of the applied op. For a 0-dim input the result
// of `inp.to(inp.device)` is returned instead and no op is applied.
//
// Throws:
//   py::index_error      - a 0-dim input was given together with arguments.
//   py::value_error      - a pattern element is a string other than "x"
//                          (previously such elements were silently dropped,
//                          producing a pattern that was too short).
//   py::cast_error       - a pattern element is neither an int nor a string.
//   py::error_already_set- converting the pattern to a list failed on the
//                          Python side.
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
    py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
    // Query the rank once and reuse it below instead of a second getattr.
    const size_t ndim = fast_ndim(inp_hdl);
    if (ndim == 0) {
        if (args_tup.size() != 0) {
            throw py::index_error(
                    "transpose for scalar does not accept additional args");
        }
        return getattr(inp_hdl, "to")(getattr(inp_hdl, "device"));
    }

    std::vector<int32_t> pattern;
    if (args_tup.size() == 0) {
        // No explicit axes: reverse all dimensions.
        pattern.reserve(ndim);
        for (size_t i = 0; i < ndim; ++i) {
            pattern.push_back(static_cast<int32_t>(ndim - i - 1));
        }
    } else {
        // A single sequence-like argument is the pattern itself; otherwise
        // the positional arguments form the pattern.
        py::object axes_src = args_tup;
        if (args_tup.size() == 1) {
            py::object first = args_tup[0];
            if (PySequence_Check(first.ptr()) ||
                is_tensor_or_symbolvar(first.ptr())) {
                axes_src = first;
            }
        }

        // PySequence_List returns NULL with a Python error set on failure;
        // stealing that into a py::list and using it would dereference null.
        PyObject* raw_list = PySequence_List(axes_src.ptr());
        if (raw_list == nullptr) {
            throw py::error_already_set();
        }
        py::list lis = py::reinterpret_steal<py::list>(raw_list);

        const size_t count = lis.size();
        pattern.reserve(count);
        for (size_t i = 0; i < count; ++i) {
            py::object item = lis[i];
            if (PyLong_Check(item.ptr())) {
                pattern.push_back(item.cast<int32_t>());
            } else if (item.cast<std::string>() == "x") {
                pattern.push_back(-1);
            } else {
                // Dropping the element would silently shift every later axis.
                throw py::value_error(
                        "transpose pattern entries must be int or \"x\"");
            }
        }
    }

    std::shared_ptr<OpDef> op = Dimshuffle::make(pattern);
    py::object Op = py::cast(op);
    // py_apply takes the op followed by its inputs as a flat pointer array.
    std::vector<PyObject*> p{Op.ptr(), inp_hdl.ptr()};
    py::tuple ret =
            py::reinterpret_steal<py::object>(py_apply(nullptr, p.data(), p.size()));
    return ret[0];
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(py::handle(args[0])).release().ptr();
......@@ -842,4 +893,11 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// C-level entry point for `transpose_cpp(tensor, args)`.
//
// args[0] is the input tensor and args[1] the tuple of transpose arguments;
// both are forwarded to _transpose_cpp. Returns a new reference to the
// result, or nullptr after a C++ exception has been translated by
// PYEXT17_TRANSLATE_EXC_RET.
PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        // Guard the fixed-index reads below: with fewer than two arguments
        // args[1] would be read past the end of the argument array.
        if (nargs < 2) {
            throw py::type_error(
                    "transpose_cpp expects 2 arguments: (tensor, args)");
        }
        return _transpose_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
......@@ -14,4 +14,6 @@ PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册