提交 a4d473c9 编写于 作者: M Megvii Engine Team

perf(mge/functional): speed up AddAxis

GitOrigin-RevId: 92a3e1bdd3c4f0d1d68d8571cad78c1c8ea0f634
上级 3e206d89
......@@ -12,7 +12,13 @@ from typing import Iterable, Optional, Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion, split_cpp
from ..core._imperative_rt.core2 import (
SymbolVar,
apply,
dtype_promotion,
expand_dims_cpp,
split_cpp,
)
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
......@@ -959,27 +965,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
(1, 2)
"""
def get_axes():
try:
return [int(axis)]
except (TypeError, ValueError):
pass
return list(map(int, axis))
axis = get_axes()
try:
ndim = inp.ndim + len(axis)
axis = sorted(i + ndim if i < 0 else i for i in axis)
except ValueError:
if any([ind < 0 for ind in axis]):
raise IndexError(
"Does not support negative index when tensor's ndim is unknown"
)
axis = sorted(axis)
assert axis, "axis could not be empty"
op = builtin.AddAxis(axis=axis)
(result,) = apply(op, inp)
return result
return expand_dims_cpp(inp, axis)
def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
......
......@@ -634,6 +634,7 @@ WRAP_FUNC_PY35(make_shape_tuple);
WRAP_FUNC_PY35(getitem_cpp);
WRAP_FUNC_PY35(setitem_cpp);
WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
......@@ -767,6 +768,7 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
MGE_PY_INTERFACE(split_cpp, split_cpp),
MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {
......
......@@ -683,6 +683,59 @@ py::object _split_cpp(
return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
}
py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
std::vector<int32_t> axis;
if (is_py_sequence(axis_hdl.ptr())) {
py::list tmp_list =
py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr()));
for (size_t i = 0; i < tmp_list.size(); ++i) {
axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
}
} else {
axis.push_back(getattr(axis_hdl, "__int__")().cast<int>());
}
bool unknown_ndim = true;
size_t ndim = axis.size();
if (auto p = TensorWrapper::try_cast(inp_hdl.ptr())) {
auto&& shape = p->m_tensor->shape();
if (shape) {
unknown_ndim = false;
ndim += shape->ndim;
}
} else {
auto&& var = inp_hdl.cast<PySymbolVar*>();
auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
auto&& shape = mgr.infer_shape_fallible(var->m_node);
if (shape) {
unknown_ndim = false;
ndim += shape->ndim;
}
}
for (size_t i = 0; i < axis.size(); ++i) {
if (axis[i] < 0) {
if (unknown_ndim) {
throw py::index_error(
"Does not support negative index when tensor's ndim is "
"unknown");
}
axis[i] += ndim;
}
}
if (!axis.size()) {
throw py::index_error("axis could not be empty");
}
std::sort(axis.begin(), axis.end());
std::shared_ptr<OpDef> op = AddAxis::make(axis = axis);
std::vector<PyObject*> p;
p.resize(2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(py::handle(args[0])).release().ptr();
......@@ -716,4 +769,13 @@ PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _expand_dims_cpp(py::handle(args[0]), py::handle(args[1]))
.release()
.ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
......@@ -10,4 +10,6 @@ PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册