From 28c6ebfef3bf65ef29fbd0df71b13a76fbdcf989 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 21 Nov 2022 11:13:28 +0800
Subject: [PATCH] perf(imperative): speed up subtensor

GitOrigin-RevId: c3d94bfde8f4d3c7e2efc46af7e17a255ee01785
---
 imperative/python/src/grad_override.cpp     |  32 ++-
 imperative/python/src/tensor_utils.cpp      | 150 +++++++++++-
 imperative/src/impl/ops/broadcast.cpp       |   3 +
 imperative/src/impl/ops/indexing.cpp        |  11 +-
 imperative/src/impl/ops/specializations.cpp |   1 -
 imperative/src/impl/ops/subtensor.cpp       | 218 ++++++++++++++++++
 .../src/impl/transformations/format.cpp     |  43 +++-
 imperative/tablegen/generated/hash.txt      |  10 +-
 imperative/tablegen/generated/opdef.cpp.inl |   3 +
 imperative/tablegen/generated/opdef.cpy.inl |  28 ++-
 imperative/tablegen/generated/opdef.h.inl   |   3 +-
 imperative/tablegen/generated/opdef.py.inl  |   5 +-
 src/core/include/megbrain/ir/ops.td         |  10 +-
 13 files changed, 487 insertions(+), 30 deletions(-)
 create mode 100644 imperative/src/impl/ops/subtensor.cpp

diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index 8e9b5db12..9579dca3f 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -325,6 +325,35 @@ std::optional<ValueRefList> subtensor_grad_rule(
             inputs2.push_back(inputs[i]);
         }
     }
+    CompNodeValue::ref_t device = inputs[0].device();
+    auto get_subtensor_index = [&](int idx) {
+        HostTensorStorage storage(*device);
+        storage.ensure_size(dtype::Int32().size());
+        auto* ptr = reinterpret_cast<dt_int32*>(storage.ptr());
+        ptr[0] = idx;
+        return imperative::apply(
+                CreateTensor(
+                        CreateTensor::Unique, *device, dtype::Int32(), ValueShape({1})),
+                HostStorage::make(storage))[0];
+    };
+    auto slice_items = subtensor.slice_items;
+    auto items = subtensor.items;
+    for (int i = 0; i < slice_items.size(); i++) {
+        auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
+        auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
+        if (b_flag) {
+            inputs2.push_back(get_subtensor_index(b_val));
+        };
+        if (e_flag) {
+            inputs2.push_back(get_subtensor_index(e_val));
+        };
+        if (s_flag) {
+            inputs2.push_back(get_subtensor_index(s_val));
+        };
+        if (idx_flag) {
+            inputs2.push_back(get_subtensor_index(ax_val));
+        };
+    };
     auto maker = CustomGradMaker(backward, inputs.size());
     maker.output_size(1).output_captured(0, false);
     maker.backward([inputs = std::move(inputs2),
@@ -647,8 +676,9 @@ std::optional<ValueRefList> warp_affine_grad_rule(
         ret[1] = imperative::apply(*grad_op, args_)[0];
 
         std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
+        std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
         items.push_back(std::make_tuple(1, true, true, false, false));
-        auto&& subtensor = Subtensor::make(items);
+        auto&& subtensor = Subtensor::make(items, slice_items);
 
         CompNodeValue::ref_t device = inputs[0].device();
         DTypeValue::ref_t dtype = inputs[0].dtype();
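
[Note, not part of the patch: the hunk above rebuilds the begin/end/step/index scalars from the new slice_items attribute so that the SetSubtensor backward still receives them as tensor inputs. A rough illustration of the encoding produced by _fastpath_getitem_cpp below; the concrete values are hypothetical, and INT_MIN/INT_MAX mark "not given":]

    # items:       one (axis, has_begin, has_end, has_step, is_index) tuple per indexed axis
    # slice_items: the matching (begin, end, step, index) values
    INT_MIN, INT_MAX = -2**31, 2**31 - 1
    # x[2, 0:6:2] would be described as:
    items       = [(0, False, False, False, True),   # axis 0: integer index 2
                   (1, True,  True,  True,  False)]  # axis 1: slice 0:6:2
    slice_items = [(INT_MIN, INT_MAX, INT_MAX, 2),
                   (0, 6, 2, INT_MAX)]
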
diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index 19a2f3f7d..7c7d398b0 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -781,14 +781,8 @@ std::pair<size_t, bool> get_ndim_safe(py::handle tensor) {
     }
 }
 
-py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
+py::tuple _unpack_indexes(py::handle inp_hdl, py::tuple tuple_val) {
     py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
-    py::tuple tuple_val;
-    if (py::isinstance<py::tuple>(idx_hdl)) {
-        tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
-    } else {
-        tuple_val = py::make_tuple(idx_hdl);
-    }
 
     bool use_subtensor = true;
     bool need_remove_ellipsis = false;
@@ -939,6 +933,20 @@ bool enable_fastpath(py::handle inp) {
     return true;
 }
 
+bool subtensor_fastpath(py::handle inp_hdl, py::tuple tuple_val) {
+    bool use_fastpath = true;
+    for (size_t i = 0; i < tuple_val.size(); ++i) {
+        PyObject* obj = tuple_val[i].ptr();
+        if ((!is_scalar(obj) && !PySlice_Check(obj) && obj != Py_Ellipsis &&
+             obj != Py_None) ||
+            (PyObject_TypeCheck(obj, py_varnode_type))) {
+            use_fastpath = false;
+            break;
+        }
+    }
+    return use_fastpath && enable_fastpath(inp_hdl);
+}
+
 py::object _broadcast_cpp(py::handle input, py::handle args) {
     py::object shape = _expand_args(args);
     py::list dims;
@@ -1128,16 +1136,129 @@ py::object _adaptive_pool2d_cpp(
     return ret[0];
 }
 
+py::object _fastpath_getitem_cpp(py::handle inp_hdl, py::tuple tuple_val) {
+    py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
+    int ax = 0;
+    bool use_ellipsis = false;
+    size_t special_dim = 0;
+
+    for (size_t i = 0; i < tuple_val.size(); ++i) {
+        PyObject* obj = tuple_val[i].ptr();
+        if (obj == Py_Ellipsis) {
+            use_ellipsis = true;
+            for (size_t j = i + 1; j < tuple_val.size(); j++) {
+                PyObject* obj_last = tuple_val[j].ptr();
+                if (obj_last == Py_Ellipsis) {
+                    throw py::index_error("only one ellipsis is allowed.");
+                }
+            }
+        }
+        if (obj != Py_None && obj != Py_Ellipsis && obj != Py_True && obj != Py_False) {
+            special_dim++;
+        }
+    }
+
+    size_t ndim = 0;
+    try {
+        ndim = getattr(inp_hdl, "ndim").cast<size_t>();
+    } catch (py::error_already_set& err) {
+        if (use_ellipsis) {
+            throw py::index_error(
+                    "does not support Ellipsis when tensor's ndim is unknown.");
+        };
+    }
+
+    std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
+    std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
+    std::vector<int32_t> expand_items;
+
+    for (size_t i = 0; i < tuple_val.size(); ++i) {
+        py::object t = tuple_val[i];
+        if (t.ptr() == Py_Ellipsis) {
+            ax += ndim - special_dim;
+        } else if (PySlice_Check(t.ptr())) {
+            PySliceObject* s = (PySliceObject*)t.ptr();
+            std::vector<int32_t> items;
+            std::vector<bool> idx_items;
+            auto push = [&](PyObject* v, int default_value) {
+                if (v == Py_None) {
+                    items.push_back(default_value);
+                    idx_items.push_back(false);
+                } else {
+                    auto obj = py::reinterpret_borrow<py::object>(v);
+                    items.push_back(obj.cast<int32_t>());
+                    idx_items.push_back(true);
+                }
+            };
+            push(s->start, INT_MIN);
+            push(s->stop, INT_MAX);
+            push(s->step, INT_MAX);
+            if (idx_items[0] || idx_items[1] || idx_items[2]) {
+                cpp_items.push_back(
+                        {ax, idx_items[0], idx_items[1], idx_items[2], false});
+                slice_items.push_back({items[0], items[1], items[2], INT_MAX});
+            }
+            ax += 1;
+        } else if (PyLong_Check(t.ptr()) && !PyBool_Check(t.ptr())) {
+            cpp_items.push_back({ax, false, false, false, true});
+            slice_items.push_back({INT_MIN, INT_MAX, INT_MAX, t.cast<int32_t>()});
+            ax += 1;
+        } else if (PyBool_Check(t.ptr())) {
+            expand_items.push_back(ax);
+        } else if (t.ptr() == Py_None) {
+            expand_items.push_back(ax);
+            ax += 1;
+        } else if (is_scalar(t.ptr())) {
+            cpp_items.push_back({ax, false, false, false, true});
+            slice_items.push_back({INT_MIN, INT_MAX, INT_MAX, t.cast<int32_t>()});
+            ax += 1;
+        } else {
+            throw py::value_error("fast path subtensor index not impl");
+        }
+    }
+
+    if (expand_items.size()) {
+        std::shared_ptr<OpDef> op = AddAxis::make(expand_items);
+        py::object Op = py::cast(op);
+        PyObject* p[2] = {Op.ptr(), inp.ptr()};
+        py::tuple ret = py::reinterpret_steal<py::tuple>(py_apply(NULL, p, 2));
+        inp = ret[0];
+    }
+
+    std::shared_ptr<OpDef> op;
+    op = Subtensor::make(cpp_items, slice_items);
+
+    std::vector<PyObject*> p;
+    p.resize(2);
+    py::object Op = py::cast(op);
+    p[0] = Op.ptr();
+    p[1] = inp.ptr();
+
+    py::tuple ret =
+            py::reinterpret_steal<py::tuple>(py_apply(NULL, p.data(), p.size()));
+    return ret[0];
+}
+
 py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
     py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
     if (try_res.size() == 2) {
         return try_res[0];
     }
-    py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
+    py::tuple tuple_val;
+    if (py::isinstance<py::tuple>(idx_hdl)) {
+        tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
+    } else {
+        tuple_val = py::make_tuple(idx_hdl);
+    }
+    if (subtensor_fastpath(inp_hdl, tuple_val)) {
+        return _fastpath_getitem_cpp(inp_hdl, tuple_val);
+    }
+    py::tuple up = _unpack_indexes(inp_hdl, tuple_val);
     py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
     py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
     py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
     std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
+    std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
     for (size_t i = 0; i < py_items.size(); ++i) {
         py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
         cpp_items.push_back(
@@ -1146,7 +1267,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
     }
     std::shared_ptr<OpDef> op;
     if (up[3].cast<bool>()) {
-        op = Subtensor::make(cpp_items);
+        op = Subtensor::make(cpp_items, slice_items);
     } else {
         op = IndexingMultiAxisVec::make(cpp_items);
     }
@@ -1170,11 +1291,18 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
         val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"));
     }
 
-    py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
+    py::tuple tuple_val;
+    if (py::isinstance<py::tuple>(idx_hdl)) {
+        tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
+    } else {
+        tuple_val = py::make_tuple(idx_hdl);
+    }
+    py::tuple up = _unpack_indexes(inp_hdl, tuple_val);
     py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
     py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
     py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
     std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
+    std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
     for (size_t i = 0; i < py_items.size(); ++i) {
         py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
         cpp_items.push_back(
@@ -1183,7 +1311,7 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
     }
     std::shared_ptr<OpDef> op, set_op;
     if (up[3].cast<bool>()) {
-        op = Subtensor::make(cpp_items);
+        op = Subtensor::make(cpp_items, slice_items);
     } else {
        op = IndexingMultiAxisVec::make(cpp_items);
     }
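
[Note, not part of the patch: subtensor_fastpath only admits indices built from Python scalars, slices, Ellipsis, None and bool, and rejects symbolic inputs (py_varnode_type), so only such expressions reach _fastpath_getitem_cpp; everything else keeps going through _unpack_indexes. A rough sketch with a hypothetical MegEngine tensor:]

    import numpy as np
    import megengine as mge

    x = mge.tensor(np.arange(24, dtype="float32").reshape(2, 3, 4))

    y0 = x[1]             # integer index          -> fast path (static Subtensor)
    y1 = x[0:2, ::2]      # slices                 -> fast path
    y2 = x[..., None, 1]  # Ellipsis / None / int  -> fast path (AddAxis + Subtensor)

    idx = mge.tensor([0, 1], dtype="int32")
    y3 = x[idx]           # tensor index           -> original IndexingMultiAxisVec path
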
diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
index d2cb2c4b3..f55ac2568 100644
--- a/imperative/src/impl/ops/broadcast.cpp
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -208,6 +208,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         cg::copy_tensor_value_to_shape(
                 tshp, tshp_nd->get_value().proxy_to_default_cpu());
     }
+    if (tshp.is_empty()) {
+        return {Tensor::make(TensorLayout(tshp, src->dtype()), src->comp_node())};
+    }
     TensorLayout tlayout = slayout.broadcast(tshp);
     // memory forward
     return {Tensor::make(src->blob(), src->offset(), tlayout)};
diff --git a/imperative/src/impl/ops/indexing.cpp b/imperative/src/impl/ops/indexing.cpp
index 6187654e6..9036faa02 100644
--- a/imperative/src/impl/ops/indexing.cpp
+++ b/imperative/src/impl/ops/indexing.cpp
@@ -103,8 +103,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     if (!src.ndim) {
         return {{{{{}, src.dtype}, comp_node}}, false};
     }
-    mgb_assert(src.is_contiguous(), "src should be contiguous");
-    return {{{src, comp_node}}, true};
+    return {{{{src, src.dtype}, comp_node}}, true};
 }
 
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
@@ -138,10 +137,18 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     dnn_op.exec_with_ws(out, index, sub);
     return {out};
 }
+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
+    layout_checker[0] = layout_checker[1] = layout_checker[2] =
+            [](const TensorLayout& layout) { return layout.is_contiguous(); };
+    return layout_checker;
+}
 
 OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .apply_on_var_node(apply_on_var_node)
+        .get_input_layout_constraint(get_input_layout_constraint)
        .apply_on_physical_tensor(apply_on_physical_tensor)
         .fallback();
 }  // namespace indexing_set_one_hot
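
[Note, not part of the patch: the broadcast.cpp hunk returns an empty output directly when the requested shape has a zero extent, and the indexing.cpp hunk replaces the hard contiguity assert on IndexingSetOneHot inputs with a per-input layout constraint, so non-contiguous inputs (for example views produced by the new Subtensor fast path) are handled instead of tripping the assert. A small illustration of the empty-shape case, assuming broadcasting to a zero-sized target is accepted, which is what the hunk adds handling for:]

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.ones((1, 4), dtype="float32"))
    y = F.broadcast_to(x, (0, 4))  # zero-sized target: result is empty, nothing is forwarded
    print(y.shape)                 # (0, 4)
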
diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp
index 4c4e68517..acbc5d716 100644
--- a/imperative/src/impl/ops/specializations.cpp
+++ b/imperative/src/impl/ops/specializations.cpp
@@ -542,7 +542,6 @@ auto get_index(
     OP_TRAIT_REG(NAME, NAME).apply_on_var_node(apply_on_var_node).fallback();   \
 }
 
-FANCY_INDEXING_IMPL(Subtensor, 1)
 FANCY_INDEXING_IMPL(SetSubtensor, 2)
 FANCY_INDEXING_IMPL(IncrSubtensor, 2)
 FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
diff --git a/imperative/src/impl/ops/subtensor.cpp b/imperative/src/impl/ops/subtensor.cpp
new file mode 100644
index 000000000..c1680e648
--- /dev/null
+++ b/imperative/src/impl/ops/subtensor.cpp
@@ -0,0 +1,218 @@
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/internal/indexing_helper.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
+#include "megbrain/opr/utility.h"
+#include "megbrain/tensor.h"
+
+#include "../algo_chooser.h"
+#include "../blob_manager_impl.h"
+#include "../dnn_op_helper.h"
+#include "../op_trait.h"
+
+using namespace mgb::opr::indexing;
+namespace mgb::imperative {
+
+namespace {
+namespace subtensor {
+
+auto get_index(
+        const VarNodeArray& inputs,
+        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask,
+        const std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>>& slice) {
+    size_t length = mask.size();
+    auto graph = inputs[0]->owner_graph();
+    auto comp_node = inputs[0]->comp_node();
+    opr::Subtensor::IndexDesc ret(length);
+    auto immutable_node = [&](int val) {
+        DTypeScalar scalar = DTypeScalar(static_cast<dt_int32>(val));
+        return opr::ImmutableTensor::make(*graph, scalar, {comp_node});
+    };
+    for (size_t i = 0; i < length; ++i) {
+        auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = mask[i];
+        auto&& [b_val, e_val, s_val, ax_val] = slice[i];
+        ret[i].axis = axis;
+        if (idx_flag) {
+            ret[i].idx = immutable_node(ax_val);
+        } else {
+            if (b_flag) {
+                ret[i].begin = immutable_node(b_val);
+            }
+            if (e_flag) {
+                ret[i].end = immutable_node(e_val);
+            }
+            if (s_flag) {
+                ret[i].step = immutable_node(s_val);
+            }
+        }
+    }
+    return ret;
+}
+
+auto origin_get_index(
+        const VarNodeArray& inputs, size_t vidx,
+        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
+    size_t length = mask.size();
+    opr::Subtensor::IndexDesc ret(length);
+    for (size_t i = 0; i < length; ++i) {
+        auto&& [axis, begin, end, step, idx] = mask[i];
+        ret[i].axis = axis;
+        if (idx) {
+            ret[i].idx = inputs[vidx++];
+        } else {
+            mgb_assert(begin || end || step);
+            if (begin)
+                ret[i].begin = inputs[vidx++];
+            if (end)
+                ret[i].end = inputs[vidx++];
+            if (step)
+                ret[i].step = inputs[vidx++];
+        }
+    }
+    mgb_assert(vidx == inputs.size());
+    return ret;
+}
+
+TensorLayout deduce_layout(
+        TensorLayout src, std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items,
+        std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items) {
+    auto mod_size = [](int v, int size_ax) -> int {
+        if (size_ax == 0)
+            return 0;
+        return v < 0 ? v + size_ax : v;
+    };
+#define CHECK(cond) \
+    mgb_assert(cond, "index out of bound: layout=%s", src.to_string().c_str())
+
+    for (int i = items.size() - 1; i >= 0; i--) {
+        auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
+        auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
+        int shape_axis = src.shape[axis];
+        int slice_step = s_val == INT_MAX ? 1 : s_val;
+        int slice_start = b_val == INT_MIN ? 0 : b_val;
+        int slice_stop = e_val == INT_MAX ? shape_axis : e_val;
+        if (slice_step > 0) {
+            slice_start = mod_size(slice_start, shape_axis);
+            slice_stop = mod_size(slice_stop, shape_axis);
+            slice_stop = std::min(slice_stop, shape_axis);
+            slice_start = std::min(slice_start, slice_stop);
+            CHECK(slice_start >= 0 && slice_stop >= slice_start &&
+                  slice_stop <= shape_axis);
+        } else {
+            slice_start = s_val == INT_MIN ? shape_axis - 1 : b_val;
+            slice_start = mod_size(slice_start, shape_axis);
+            slice_stop = e_val == INT_MAX ? -1 : mod_size(e_val, shape_axis);
+            slice_start = std::min(slice_start, std::max(shape_axis - 1, 0));
+            slice_stop = std::min(slice_stop, slice_start);
+            CHECK(slice_step < 0 && slice_start >= 0 && slice_stop <= slice_start &&
+                  slice_start < shape_axis && slice_stop >= -1);
+        }
+        int abs_step = std::abs(slice_step);
+        if (axis < 0) {
+            axis = axis + src.ndim;
+        };
+
+        if (idx_flag == true) {
+            if (src.ndim == 1) {
+                src.shape[0] = 1;
+            } else {
+                src.remove_axis_inplace(axis);
+            }
+        } else {
+            src.shape[axis] =
+                    (std::abs(slice_stop - slice_start) + abs_step - 1) / abs_step;
+            src.stride[axis] *= slice_step;
+        }
+    }
+    return src;
+}
+
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const Subtensor&>(def);
+    OperatorNodeConfig config{op.make_name()};
+    if (inputs.size() > 1) {
+        return opr::Subtensor::make(
+                inputs[0], origin_get_index(inputs, 1, op.items), config);
+    } else {
+        return opr::Subtensor::make(
+                inputs[0], get_index(inputs, op.items, op.slice_items), config);
+    }
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    if (inputs.size() >= 2) {
+        return proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
+    }
+    auto&& inp = inputs[0];
+    auto& inp_cn = inp.comp_node;
+    if (inp.layout.ndim == 0) {
+        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}}, false};
+    }
+    auto&& op = static_cast<const Subtensor&>(def);
+
+    auto items = op.items;
+    auto slice_items = op.slice_items;
+    TensorLayout out_layout = deduce_layout(inp.layout, items, slice_items);
+
+    return {{{out_layout, inp_cn, {}}}, true};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    CompNode cn = inputs[0]->comp_node();
+    auto&& layout = inputs[0]->layout();
+    auto&& op = static_cast<const Subtensor&>(def);
+
+    if (inputs.size() > 1) {
+        return proxy_graph_detail::apply_on_physical_tensor(
+                def, inputs, output_descs, validated);
+    }
+    auto&& src = inputs[0];
+    auto slice_items = op.slice_items;
+    auto items = op.items;
+    TensorLayout res_layout = deduce_layout(layout, items, slice_items);
+    if (res_layout.is_empty()) {
+        return {Tensor::make(res_layout, cn)};
+    }
+    size_t offset = 0;
+    size_t dtype_size = layout.dtype.size();
+    TensorPtr tensor = src;
+    for (int i = items.size() - 1; i >= 0; i--) {
+        auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
+        auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
+        int start = b_val;
+        if (idx_flag) {
+            ax_val = ax_val < 0 ? layout.shape[axis] + ax_val : ax_val;
+            offset += ax_val * layout.stride[axis] * dtype_size;
+        } else {
+            start = std::max(start, 0);
+            offset += start * layout.stride[axis] * dtype_size;
+        }
+    }
+
+    // memory forward
+    return {Tensor::make(src->blob(), src->offset() + offset, res_layout)};
+}
+
+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
+    return layout_checker;
+}
+
+OP_TRAIT_REG(Subtensor, Subtensor, opr::Subtensor)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .get_input_layout_constraint(get_input_layout_constraint)
+        .fallback();
+
+}  // namespace subtensor
+}  // namespace
+
+}  // namespace mgb::imperative
\ No newline at end of file
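
[Note, not part of the patch: deduce_layout and apply_on_physical_tensor above turn a static slice description into a strided view: slicing an axis rescales its extent and stride, an integer index removes the axis and adds to the byte offset, and the result reuses the input blob. A rough worked example of the same arithmetic, checked with NumPy as a stand-in; the tensor and index are hypothetical:]

    import numpy as np

    x = np.arange(24, dtype="float32").reshape(2, 3, 4)  # element strides (12, 4, 1)
    view = x[1, 0:3:2]                                   # what the fast path computes

    # axis 1, slice 0:3:2 -> extent (3 - 0 + 2 - 1) // 2 = 2, stride 4 * 2 = 8 elements
    # axis 0, index 1     -> axis removed, byte offset += 1 * 12 * 4 = 48
    assert view.shape == (2, 4)
    assert view.strides == (32, 4)  # element strides (8, 1) in float32 bytes
    assert view.ctypes.data - x.ctypes.data == 48
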
diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp
index f7e4f4a49..66c65730b 100644
--- a/imperative/src/impl/transformations/format.cpp
+++ b/imperative/src/impl/transformations/format.cpp
@@ -306,6 +306,19 @@ inline bool is_reduce_ndim_idx_items(
     return false;
 }
 
+inline bool is_subtensor_reduce_ndim(
+        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items,
+        const std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items) {
+    for (auto i = 0; i < items.size(); ++i) {
+        auto&& [axis, begin, end, step, idx] = items[i];
+        if (idx) {
+            auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
+            return ax_val != INT_MAX;
+        }
+    }
+    return false;
+}
+
 inline auto convert_nchw2nhwc_idx_items(
         const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) {
     auto nhwc_items = items;
@@ -326,6 +339,34 @@ ValueRefList subtensor_rule(
         const FormatTransformation& t) {
     mgb_assert(inputs.size() >= 1);
     auto& src = inputs[0].cast(t.value_type());
+    bool is_reduce_ndim = false;
+    if (inputs.size() > 1) {
+        is_reduce_ndim = is_reduce_ndim_idx_items(
+                op.items, {&inputs[1], &inputs[inputs.size() - 1]});
+    } else {
+        is_reduce_ndim = is_subtensor_reduce_ndim(op.items, op.slice_items);
+    }
+    if (!is_reduce_ndim) {
+        // only support NHWC2NCHW convert, otherwise maintain src's format
+        if (!(auto_convert && src.format() == FT::NHWC)) {
+            return {t.wrap_output(
+                    imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
+        }
+        auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
+        auto outputs = imperative::apply(
+                *T::make(std::move(nhwc_items), op.slice_items, op.scope()),
+                t.unwrap_inputs(inputs));
+        return t.wrap_outputs(outputs, FT::NHWC);
+    }
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
+}
+
+template <typename T>
+ValueRefList indexing_rule(
+        const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
+    mgb_assert(inputs.size() >= 1);
+    auto& src = inputs[0].cast(t.value_type());
     bool is_reduce_ndim = is_reduce_ndim_idx_items(
             op.items, {&inputs[1], &inputs[inputs.size() - 1]});
     if (!is_reduce_ndim) {
@@ -597,7 +638,7 @@ struct FormatRuleRegistry {
         register_format_rule(reshape_rule);
         register_format_rule(broadcast_rule);
         register_format_rule(subtensor_rule<Subtensor>);
-        register_format_rule(subtensor_rule<IndexingMultiAxisVec>);
+        register_format_rule(indexing_rule<IndexingMultiAxisVec>);
         register_format_rule(setsubtensor_rule<SetSubtensor>);
         register_format_rule(setsubtensor_rule<IncrSubtensor>);
         register_format_rule(elemwise_rule);
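
[Note, not part of the patch: with the slice carried as attributes, subtensor_rule can tell from op.items/op.slice_items alone whether the indexing drops an axis, without inspecting tensor inputs; IndexingMultiAxisVec keeps the old tensor-input check in the new indexing_rule. The distinction it draws, illustrated with NumPy shapes on a hypothetical NHWC-like tensor:]

    import numpy as np

    x = np.zeros((2, 4, 4, 3), dtype="float32")  # think of it as an NHWC tensor
    print(x[:, :, :, 0:1].shape)  # (2, 4, 4, 1): ndim kept    -> rule may stay in NHWC
    print(x[:, :, :, 0].shape)    # (2, 4, 4):    ndim reduced -> falls back to default handling
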
diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt
index 511a1e433..1c21d4f50 100644
--- a/imperative/tablegen/generated/hash.txt
+++ b/imperative/tablegen/generated/hash.txt
@@ -1,7 +1,7 @@
 8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py
-cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td
-f27cdbb7926e0be9f5dabb8651d2e4da generated/opdef.h.inl
-96817f709ee92c8e1eb7cb4168f28565 generated/opdef.cpp.inl
-672668fa3ed11c27781f0fa380e6c8aa generated/opdef.py.inl
-47511e3e7fed8c64a1c4fea48d79b3d1 generated/opdef.cpy.inl
+6811fde221f86d1ef8de425a3c83127b ../../src/core/include/megbrain/ir/ops.td
+55123da1605ef6edd79e3a2ede8aefeb generated/opdef.h.inl
+6f4beb6d12cdd9ec4c4e61b6d7d35144 generated/opdef.cpp.inl
+185ba3c3a0fce480ee498cef058670b2 generated/opdef.py.inl
+b7ed7a638b7586709bb23dd153fb58b1 generated/opdef.cpy.inl
 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl
index 6ec6e918e..80e2b3efc 100644
--- a/imperative/tablegen/generated/opdef.cpp.inl
+++ b/imperative/tablegen/generated/opdef.cpp.inl
@@ -6918,6 +6918,7 @@ size_t Subtensor_hash_impl(const OpDef& def_) {
     static_cast<void>(op_);
     size_t val = mgb::hash(op_.dyn_typeinfo());
     val = mgb::hash_pair_combine(val, mgb::hash(op_.items));
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.slice_items));
     return val;
 }
 bool Subtensor_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
@@ -6926,6 +6927,7 @@ bool Subtensor_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
     static_cast<void>(a_);
     static_cast<void>(b_);
     if (a_.items != b_.items) return false;
+    if (a_.slice_items != b_.slice_items) return false;
     return true;
 }
 std::vector<std::pair<const char*, std::string>> Subtensor_props_impl(const OpDef& def_) {
@@ -6933,6 +6935,7 @@ std::vector<std::pair<const char*, std::string>> Subtensor_props_impl(const OpDef& def_) {
     static_cast<void>(op_);
     std::vector<std::pair<const char*, std::string>> props_;
     props_.emplace_back("items", "{std::vector}");
+    props_.emplace_back("slice_items", "{std::vector}");
     return props_;
 }
 std::string Subtensor_make_name_impl(const OpDef& def_) {
diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl
index 42b837096..1f9882bee 100644
--- a/imperative/tablegen/generated/opdef.cpy.inl
+++ b/imperative/tablegen/generated/opdef.cpy.inl
@@ -20113,7 +20113,8 @@ PyOpDefBegin(Subtensor) // {
         static_cast<void>(opdef);
         std::unordered_map<std::string, py::object> state {
 
-            {"items", serialization<decltype(opdef.items)>::dump(opdef.items)}
+            {"items", serialization<decltype(opdef.items)>::dump(opdef.items)},
+            {"slice_items", serialization<decltype(opdef.slice_items)>::dump(opdef.slice_items)}
         };
         return py::cast(state).release().ptr();
     }
@@ -20130,6 +20131,13 @@ PyOpDefBegin(Subtensor) // {
                 opdef.items = serialization<decltype(opdef.items)>::load(iter->second);
             }
         }
+
+        {
+            auto&& iter = state.find("slice_items");
+            if (iter != state.end()) {
+                opdef.slice_items = serialization<decltype(opdef.slice_items)>::load(iter->second);
+            }
+        }
         Py_RETURN_NONE;
     }
     static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
@@ -20139,9 +20147,9 @@ PyOpDefBegin(Subtensor) // {
 PyOpDefEnd(Subtensor)
 
 int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
-    static const char* kwlist[] = {"items", "scope", NULL};
-    PyObject *items = NULL, *scope = NULL;
-    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &items, &scope))
+    static const char* kwlist[] = {"items", "slice_items", "scope", NULL};
+    PyObject *items = NULL, *slice_items = NULL, *scope = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO", const_cast<char**>(kwlist), &items, &slice_items, &scope))
         return -1;
 
     if (items) {
@@ -20153,6 +20161,15 @@ int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
         } CATCH_ALL(-1)
     }
 
+    if (slice_items) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(Subtensor)*>(self)->inst().slice_items =
+                    py::cast<decltype(Subtensor::slice_items)>(py::handle(slice_items));
+        } CATCH_ALL(-1)
+    }
+
     if (scope) {
         try {
             reinterpret_cast<PyOp(OpDef)*>(self)->op
@@ -20165,6 +20182,7 @@ int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
 
 PyGetSetDef PyOp(Subtensor)::py_getsetters[] = {
     {const_cast<char*>("items"), py_get_generic(Subtensor, items), py_set_generic(Subtensor, items), const_cast<char*>("items"), NULL},
+    {const_cast<char*>("slice_items"), py_get_generic(Subtensor, slice_items), py_set_generic(Subtensor, slice_items), const_cast<char*>("slice_items"), NULL},
     {NULL}  /* Sentinel */
 };
 
@@ -20185,7 +20203,7 @@ PyMethodDef PyOp(Subtensor)::py_init_methoddef = {
     "__init__",
     (PyCFunction)PyOp(Subtensor)::py_init_proxy,
     METH_VARARGS | METH_KEYWORDS,
-    "__init__(self, items: list[tuple[int, bool, bool, bool, bool]] = ...) -> None\n"
+    "__init__(self, items: list[tuple[int, bool, bool, bool, bool]] = ..., slice_items: list[tuple[int, int, int, int]] = ...) -> None\n"
 };
 
 void _init_py_Subtensor(py::module m) {
diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl
index 88fe7c61a..2ad322921 100644
--- a/imperative/tablegen/generated/opdef.h.inl
+++ b/imperative/tablegen/generated/opdef.h.inl
@@ -1795,8 +1795,9 @@ class Subtensor : public OpDefImplBase<Subtensor> {
 
 public:
     std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
+    std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
     Subtensor() = default;
-    Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
+    Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items_, std::string scope_ = {}): items(items_), slice_items(slice_items_) { set_scope(scope_); }
 };
 
 class TQT : public OpDefImplBase<TQT> {
diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl
index 9857bf365..1ceae2c0d 100644
--- a/imperative/tablegen/generated/opdef.py.inl
+++ b/imperative/tablegen/generated/opdef.py.inl
@@ -1882,9 +1882,10 @@ SplitInst
 
 py::class_<Subtensor, std::shared_ptr<Subtensor>, OpDef> SubtensorInst(m, "Subtensor");
 
 SubtensorInst
-    .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
+    .def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>>, std::string>(), py::arg("items"), py::arg("slice_items"), py::arg("scope") = {})
     .def(py::init<>())
-    .def_readwrite("items", &Subtensor::items);
+    .def_readwrite("items", &Subtensor::items)
+    .def_readwrite("slice_items", &Subtensor::slice_items);
 
 py::class_<TQT, std::shared_ptr<TQT>, OpDef> TQTInst(m, "TQT");
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index 221ee8769..9cf83dd91 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -380,7 +380,15 @@ class FancyIndexingBase<string name>: MgbHashableOp<name> {
   );
 }
 
-def Subtensor: FancyIndexingBase<"Subtensor">;
+def Subtensor: MgbHashableOp<"Subtensor"> {
+  let extraArguments = (ins
+    MgbArrayAttr<MgbTupleAttr<[MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]>>:$items,
+    MgbArrayAttr<MgbTupleAttr<[MgbI32Attr, MgbI32Attr, MgbI32Attr, MgbI32Attr]>>:$slice_items
+  );
+}
+
+// def Subtensor: FancyIndexingBase<"Subtensor">;
 def SetSubtensor: FancyIndexingBase<"SetSubtensor">;
 def IncrSubtensor: FancyIndexingBase<"IncrSubtensor">;
 def IndexingMultiAxisVec: FancyIndexingBase<"IndexingMultiAxisVec">;
-- 
GitLab
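
[Note, not part of the patch: taken together, a plain scalar/slice __getitem__ no longer materialises one small index tensor per begin/end/step on every call; the indices travel as op attributes, the output layout is computed host-side by deduce_layout, and the result is a zero-copy view of the input blob. A rough way to observe the effect; the script is hypothetical and absolute numbers depend on device and build:]

    import time
    import numpy as np
    import megengine as mge

    x = mge.tensor(np.random.rand(64, 64, 64).astype("float32"))

    def bench(n=10000):
        t0 = time.perf_counter()
        for _ in range(n):
            y = x[1, ::2, 3:40]  # scalars and slices only: served by the new fast path
        return time.perf_counter() - t0

    print("getitem loop:", bench(), "s")
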