Commit 3e206d89 authored by: Megvii Engine Team

perf(mge/functional): speed up Split

GitOrigin-RevId: 43550a0706a2794421de56067a11864c10b85c67
Parent 730ddc2d
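Note: the user-facing entry point affected here is `megengine.functional.split`; the commit moves its Python-side argument handling into a C++ `split_cpp` helper. A minimal usage sketch (not the verbatim docstring; `x`, `y`, `z` are illustrative names, and the printed shapes follow the docstring examples retained in the hunk below):

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.random((10, 20)).astype(np.float32))

# An integer asks for that many near-equal sections along `axis` (default 0).
y = F.split(x, 3)
print([i.numpy().shape for i in y])   # [(4, 20), (3, 20), (3, 20)]

# A sequence gives explicit split boundaries along `axis`.
z = F.split(x, [6, 17], axis=1)
print([i.numpy().shape for i in z])   # [(10, 6), (10, 11), (10, 3)]
```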
......@@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Tuple, Union
import numpy as np
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion, split_cpp
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
......@@ -477,50 +477,8 @@ def split(inp, nsplits_or_sections, axis=0):
        [(4, 20), (3, 20), (3, 20)]
        [(10, 6), (10, 11), (10, 3)]
    """
    ndim = len(inp.shape)
    if axis >= ndim:
        raise ValueError("Invalid axis {}".format(axis))
    Ntotal = inp.shape[axis]
    if isinstance(nsplits_or_sections, Sequence):
        Nsections = len(nsplits_or_sections) + 1
        is_array = True
    else:
        Nsections = int(nsplits_or_sections)
        is_array = False
    if is_array:
        partitions = []
        div_points = [0] + list(nsplits_or_sections) + [Ntotal]
        for i in range(1, len(div_points)):
            if div_points[i - 1] > div_points[i]:
                raise ValueError(
                    "Invalid nsplits_or_secions: {}".format(nsplits_or_sections)
                )
            partitions.append(div_points[i] - div_points[i - 1])
    else:  # scalar
        if Nsections <= 0:
            raise ValueError("Number sections must be larger than 0")
        if Nsections > Ntotal:
            raise ValueError(
                "The size {} at dim {} cannot be split into {} sections".format(
                    Ntotal, axis, Nsections
                )
            )
        partitions = []
        for i in range(Nsections):
            section_size = (Ntotal + Nsections - i - 1) // Nsections
            partitions.append(section_size)
    partitions = [
        part
        if isinstance(part, (SymbolVar, Tensor))
        else Const(part, dtype="int32", device=inp.device)(inp)[0]
        for part in partitions
    ]
    op = builtin.Split(axis=axis)
    return apply(op, inp, *partitions)
    return split_cpp(inp, nsplits_or_sections, axis)
def _get_idx(index, axis):
......
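Note: the removed scalar branch sized each section with a ceiling-style division, `(Ntotal + Nsections - i - 1) // Nsections`, so the remainder goes to the leading parts. The C++ path (`Split::make(axis, n_sections)`, resolved via `Options::make_average` further down) is expected to reproduce the same sizes. A standalone sketch of that arithmetic; `average_partitions` is an illustrative name, not a MegEngine function:

```python
def average_partitions(n_total, n_sections):
    """Mirror of the removed Python loop: larger parts come first."""
    return [(n_total + n_sections - i - 1) // n_sections for i in range(n_sections)]

assert average_partitions(10, 3) == [4, 3, 3]   # matches the (4, 20)/(3, 20)/(3, 20) shapes above
assert sum(average_partitions(10, 3)) == 10
```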
......@@ -633,6 +633,7 @@ WRAP_FUNC_PY35(get_device);
WRAP_FUNC_PY35(make_shape_tuple);
WRAP_FUNC_PY35(getitem_cpp);
WRAP_FUNC_PY35(setitem_cpp);
WRAP_FUNC_PY35(split_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
......@@ -765,6 +766,7 @@ void init_tensor(py::module m) {
            MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
            MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
            MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
            MGE_PY_INTERFACE(split_cpp, split_cpp),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
......
......@@ -603,6 +603,86 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
return res;
}
bool is_tensor_or_symbolvar(py::handle arg) {
    return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg);
}

bool is_py_sequence(py::handle arg) {
    if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) ||
        py::isinstance<PySymbolVar>(arg)) {
        return false;
    }
    return PySequence_Check(arg.ptr());
}
py::object _split_cpp(
        py::handle inp_hdl, py::handle nsplits_or_sections_hdl, py::handle axis_hdl) {
    py::object shape_obj = getattr(inp_hdl, "shape");
    py::object n_total = shape_obj[axis_hdl];
    int ndim = shape_obj.attr("__len__")().cast<int>();
    int axis = axis_hdl.cast<int>();
    if (axis >= ndim) {
        throw py::value_error("Invalid axis " + std::to_string(axis));
    }
    int n_sections;
    bool is_array;
    if (is_py_sequence(nsplits_or_sections_hdl)) {
        n_sections = PySequence_Length(nsplits_or_sections_hdl.ptr()) + 1;
        is_array = true;
    } else {
        n_sections = getattr(nsplits_or_sections_hdl, "__int__")().cast<int>();
        is_array = false;
    }
    py::list partitions;
    std::shared_ptr<OpDef> op;
    std::vector<PyObject*> p;
    if (is_array) {
        // Sequence input: turn the split boundaries into per-partition sizes and
        // pass each size as an extra input of Split(axis, nsections=0).
        py::list div_points;
        py::list sections = py::reinterpret_borrow<py::object>(nsplits_or_sections_hdl);
        div_points.append(0);
        for (size_t i = 0; i < sections.size(); ++i) {
            div_points.append(sections[i]);
        }
        div_points.append(n_total);
        for (size_t i = 1; i < div_points.size(); ++i) {
            if (div_points[i - 1] > div_points[i]) {
                throw py::value_error(
                        "Invalid nsplits_or_sections: " +
                        repr(nsplits_or_sections_hdl).cast<std::string>());
            }
            py::object pos = div_points[i] - div_points[i - 1];
            if (is_tensor_or_symbolvar(pos)) {
                partitions.append(pos);
            } else {
                partitions.append(
                        _Const(pos, py::cast((mgb::DType)dtype::Int32()),
                               getattr(inp_hdl, "device"), inp_hdl));
            }
        }
        op = Split::make(axis, 0);
        p.resize(partitions.size() + 2);
        for (size_t i = 0; i < partitions.size(); ++i) {
            p[i + 2] = partitions[i].ptr();
        }
    } else {
        // Scalar input: Split(axis, n_sections) computes near-equal parts itself.
        if (n_sections <= 0) {
            throw py::value_error("Number sections must be larger than 0");
        }
        if (py::int_(n_sections) > n_total) {
            throw py::value_error(
                    "The size " + repr(n_total).cast<std::string>() + " at dim " +
                    std::to_string(axis) + " cannot be split into " +
                    std::to_string(n_sections) + " sections");
        }
        op = Split::make(axis, n_sections);
        p.resize(2);
    }
    py::object Op = py::cast(op);
    p[0] = Op.ptr();
    p[1] = inp_hdl.ptr();
    return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _make_shape_tuple(py::handle(args[0])).release().ptr();
......@@ -627,4 +707,13 @@ PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _split_cpp(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
......@@ -8,4 +8,6 @@ PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file
......@@ -285,7 +285,7 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
            opt.method == Options::Method::SPECIFY,
            "only Split with SPECIFY output shapes is supported");
    mgb_assert(opt.partition.size() == opt.nr_part);
    return Split::make(axis);
    return Split::make(axis, 0);
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......@@ -293,13 +293,18 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& sp = static_cast<const Split&>(def);
    OperatorNodeConfig config{sp.make_name()};
    opr::Split::Options opt;
    opt.axis = sp.axis;
    opt.method = Options::Method::SPECIFY;
    mgb_assert(inputs.size() > 1);
    opt.nr_part = inputs.size() - 1;
    opt.partition.resize(opt.nr_part);
    for (size_t i = 1; i < inputs.size(); ++i)
        opt.partition[i - 1] = inputs[i];
    if (sp.nsections) {
        opt = Options::make_average(sp.axis, sp.nsections);
        opt.method = Options::Method::CALL_BACK;
    } else {
        opt.axis = sp.axis;
        opt.method = Options::Method::SPECIFY;
        mgb_assert(inputs.size() > 1);
        opt.nr_part = inputs.size() - 1;
        opt.partition.resize(opt.nr_part);
        for (size_t i = 1; i < inputs.size(); ++i)
            opt.partition[i - 1] = inputs[i];
    }
    return opr::Split::make(inputs[0], opt, config);
}
......
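Note: with the new `nsections` attribute, `apply_on_var_node` now picks between two `opr::Split` modes: a nonzero `nsections` requests an average split computed inside the operator (`CALL_BACK`), while zero keeps the old behavior of reading one explicit partition size per extra input (`SPECIFY`). A rough pure-Python mirror of that dispatch, for illustration only (`describe_split_dispatch` is hypothetical, not MegEngine API):

```python
def describe_split_dispatch(nsections, num_inputs):
    """Illustrative only: mirrors the branch added in apply_on_var_node."""
    if nsections:
        # Split(axis, nsections != 0): partition sizes computed internally (average split).
        return {"method": "CALL_BACK", "nr_part": nsections}
    # Split(axis, 0): one extra input per partition carries its size.
    assert num_inputs > 1, "expects the data tensor plus at least one partition size"
    return {"method": "SPECIFY", "nr_part": num_inputs - 1}

print(describe_split_dispatch(3, 1))   # {'method': 'CALL_BACK', 'nr_part': 3}
print(describe_split_dispatch(0, 4))   # {'method': 'SPECIFY', 'nr_part': 3}
```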
......@@ -426,7 +426,8 @@ def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>;
def Split: MgbHashableOp<"Split", [EmptyParam]> {
  let extraArguments = (ins
    MgbI32Attr:$axis
    MgbI32Attr:$axis,
    MgbI32Attr:$nsections
  );
}
......
......@@ -422,7 +422,7 @@ public:
    /*!
     * \brief make split option by splitting into average parts
     */
    static Options make_average(int axis, size_t nr_part);
    MGE_WIN_DECLSPEC_FUC static Options make_average(int axis, size_t nr_part);
    static Options make_partition(int axis, const SymbolVarArray& partition);
    static Options make_partition(
......