diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 1937ed0b0c7ef7a35941350535e62e8b9dfa7508..20f047c78bb39c3d436f83adf49a464efbc1d0ba 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Tuple, Union import numpy as np from ..core._imperative_rt import CompNode -from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion +from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion, split_cpp from ..core._wrap import as_device from ..core.ops import builtin from ..core.ops.builtin import Copy, Identity @@ -477,50 +477,8 @@ def split(inp, nsplits_or_sections, axis=0): [(4, 20), (3, 20), (3, 20)] [(10, 6), (10, 11), (10, 3)] """ - ndim = len(inp.shape) - if axis >= ndim: - raise ValueError("Invalid axis {}".format(axis)) - Ntotal = inp.shape[axis] - - if isinstance(nsplits_or_sections, Sequence): - Nsections = len(nsplits_or_sections) + 1 - is_array = True - else: - Nsections = int(nsplits_or_sections) - is_array = False - - if is_array: - partitions = [] - div_points = [0] + list(nsplits_or_sections) + [Ntotal] - for i in range(1, len(div_points)): - if div_points[i - 1] > div_points[i]: - raise ValueError( - "Invalid nsplits_or_secions: {}".format(nsplits_or_sections) - ) - partitions.append(div_points[i] - div_points[i - 1]) - else: # scalar - if Nsections <= 0: - raise ValueError("Number sections must be larger than 0") - if Nsections > Ntotal: - raise ValueError( - "The size {} at dim {} cannot be split into {} sections".format( - Ntotal, axis, Nsections - ) - ) - partitions = [] - for i in range(Nsections): - section_size = (Ntotal + Nsections - i - 1) // Nsections - partitions.append(section_size) - - partitions = [ - part - if isinstance(part, (SymbolVar, Tensor)) - else Const(part, dtype="int32", device=inp.device)(inp)[0] - for part in partitions - ] - op = builtin.Split(axis=axis) - return apply(op, inp, *partitions) + return split_cpp(inp, nsplits_or_sections, axis) def _get_idx(index, axis): diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 7847ccff2d2c40e18fefcf0e8ba74c794215fc03..6cc1afc208f154515429f85196bfaaea113c28d6 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -633,6 +633,7 @@ WRAP_FUNC_PY35(get_device); WRAP_FUNC_PY35(make_shape_tuple); WRAP_FUNC_PY35(getitem_cpp); WRAP_FUNC_PY35(setitem_cpp); +WRAP_FUNC_PY35(split_cpp); #undef WRAP_FUNC_PY35 #define MGE_PY_INTERFACE(NAME, FUNC) \ { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } @@ -765,6 +766,7 @@ void init_tensor(py::module m) { MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple), MGE_PY_INTERFACE(getitem_cpp, getitem_cpp), MGE_PY_INTERFACE(setitem_cpp, setitem_cpp), + MGE_PY_INTERFACE(split_cpp, split_cpp), {nullptr, nullptr, 0, nullptr}}; for (auto&& def : method_defs) { if (def.ml_meth != nullptr) { diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 6f5687cceedf6143fd1d2bf34528addee40b90d4..d6a89f3f6e42bea370cd0eda3650883be9c1c3e1 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -603,6 +603,86 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h return res; } +bool is_tensor_or_symbolvar(py::handle arg) { + return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance(arg); +} + +bool is_py_sequence(py::handle arg) { + if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) || + py::isinstance(arg)) { + return false; + } + return PySequence_Check(arg.ptr()); +} + +py::object _split_cpp( + py::handle inp_hdl, py::handle nsplits_or_sections_hdl, py::handle axis_hdl) { + py::object shape_obj = getattr(inp_hdl, "shape"); + py::object n_total = shape_obj[axis_hdl]; + int ndim = shape_obj.attr("__len__")().cast(); + int axis = axis_hdl.cast(); + if (axis >= ndim) { + throw py::value_error("Invalid axis " + std::to_string(axis)); + } + int n_sections; + bool is_array; + if (is_py_sequence(nsplits_or_sections_hdl)) { + n_sections = PySequence_Length(nsplits_or_sections_hdl.ptr()) + 1; + is_array = true; + } else { + n_sections = getattr(nsplits_or_sections_hdl, "__int__")().cast(); + is_array = false; + } + py::list partitions; + std::shared_ptr op; + std::vector p; + if (is_array) { + py::list div_points; + py::list sections = py::reinterpret_borrow(nsplits_or_sections_hdl); + div_points.append(0); + for (size_t i = 0; i < sections.size(); ++i) { + div_points.append(sections[i]); + } + div_points.append(n_total); + for (size_t i = 1; i < div_points.size(); ++i) { + if (div_points[i - 1] > div_points[i]) { + throw py::value_error( + "Invalid nsplits_or_secions: " + + repr(nsplits_or_sections_hdl).cast()); + } + py::object pos = div_points[i] - div_points[i - 1]; + if (is_tensor_or_symbolvar(pos)) { + partitions.append(pos); + } else { + partitions.append( + _Const(pos, py::cast((mgb::DType)dtype::Int32()), + getattr(inp_hdl, "device"), inp_hdl)); + } + } + op = Split::make(axis, 0); + p.resize(partitions.size() + 2); + for (size_t i = 0; i < partitions.size(); ++i) { + p[i + 2] = partitions[i].ptr(); + } + } else { + if (n_sections <= 0) { + throw py::value_error("Number sections must be larger than 0"); + } + if (py::int_(n_sections) > n_total) { + throw py::value_error( + "The size " + repr(n_total).cast() + " at dim " + + std::to_string(axis) + " cannot be split into " + + std::to_string(n_sections) + " sections"); + } + op = Split::make(axis, n_sections); + p.resize(2); + } + py::object Op = py::cast(op); + p[0] = Op.ptr(); + p[1] = inp_hdl.ptr(); + return py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); +} + PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { try { return _make_shape_tuple(py::handle(args[0])).release().ptr(); @@ -627,4 +707,13 @@ PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) { PYEXT17_TRANSLATE_EXC_RET(nullptr) } +PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _split_cpp(py::handle(args[0]), py::handle(args[1]), py::handle(args[2])) + .release() + .ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + } // namespace mgb::imperative::python diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h index 420b94bd34a25a202de233b3504871a2bdbbd639..cc35ec41ccc7a6ac0ce5911c2a4641769583c672 100644 --- a/imperative/python/src/tensor_utils.h +++ b/imperative/python/src/tensor_utils.h @@ -8,4 +8,6 @@ PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); +PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs); + } // namespace mgb::imperative::python \ No newline at end of file diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 7dac558b8b27329c93efafda45540f0915b9db08..c24b06b7556ed3a5c8f41ffdfc6cf48ff8b30776 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -285,7 +285,7 @@ std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { opt.method == Options::Method::SPECIFY, "only Split with SPECIFY output shapes is supported"); mgb_assert(opt.partition.size() == opt.nr_part); - return Split::make(axis); + return Split::make(axis, 0); } auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { @@ -293,13 +293,18 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& sp = static_cast(def); OperatorNodeConfig config{sp.make_name()}; opr::Split::Options opt; - opt.axis = sp.axis; - opt.method = Options::Method::SPECIFY; - mgb_assert(inputs.size() > 1); - opt.nr_part = inputs.size() - 1; - opt.partition.resize(opt.nr_part); - for (size_t i = 1; i < inputs.size(); ++i) - opt.partition[i - 1] = inputs[i]; + if (sp.nsections) { + opt = Options::make_average(sp.axis, sp.nsections); + opt.method = Options::Method::CALL_BACK; + } else { + opt.axis = sp.axis; + opt.method = Options::Method::SPECIFY; + mgb_assert(inputs.size() > 1); + opt.nr_part = inputs.size() - 1; + opt.partition.resize(opt.nr_part); + for (size_t i = 1; i < inputs.size(); ++i) + opt.partition[i - 1] = inputs[i]; + } return opr::Split::make(inputs[0], opt, config); } diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index d90bdcadf82b3e577fba9e4df8643f5bd452122d..68432bb7f8e354f3ceb6263c2602057fa97f2841 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -426,7 +426,8 @@ def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; def Split: MgbHashableOp<"Split", [EmptyParam]> { let extraArguments = (ins - MgbI32Attr:$axis + MgbI32Attr:$axis, + MgbI32Attr:$nsections ); } diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h index 01b2fa99793c1735b85893012d8b9adf61cd24a1..d558ee3e6395d9d6530ea65c68cffa34048466a6 100644 --- a/src/opr/include/megbrain/opr/tensor_manip.h +++ b/src/opr/include/megbrain/opr/tensor_manip.h @@ -422,7 +422,7 @@ public: /*! * \brief make split option by splitting into average parts */ - static Options make_average(int axis, size_t nr_part); + MGE_WIN_DECLSPEC_FUC static Options make_average(int axis, size_t nr_part); static Options make_partition(int axis, const SymbolVarArray& partition); static Options make_partition(