提交 42996679 编写于 作者: M Megvii Engine Team

perf(imperative): speed up subtensor

GitOrigin-RevId: c3d94bfde8f4d3c7e2efc46af7e17a255ee01785
上级 c9d265c7
......@@ -325,6 +325,35 @@ std::optional<ValueRefList> subtensor_grad_rule(
inputs2.push_back(inputs[i]);
}
}
CompNodeValue::ref_t device = inputs[0].device();
auto get_subtensor_index = [&](int idx) {
HostTensorStorage storage(*device);
storage.ensure_size(dtype::Int32().size());
auto* ptr = reinterpret_cast<dt_int32*>(storage.ptr());
ptr[0] = idx;
return imperative::apply(
CreateTensor(
CreateTensor::Unique, *device, dtype::Int32(), ValueShape({1})),
HostStorage::make(storage))[0];
};
auto slice_items = subtensor.slice_items;
auto items = subtensor.items;
for (int i = 0; i < slice_items.size(); i++) {
auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
if (b_flag) {
inputs2.push_back(get_subtensor_index(b_val));
};
if (e_flag) {
inputs2.push_back(get_subtensor_index(e_val));
};
if (s_flag) {
inputs2.push_back(get_subtensor_index(s_val));
};
if (idx_flag) {
inputs2.push_back(get_subtensor_index(ax_val));
};
};
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([inputs = std::move(inputs2),
......@@ -647,8 +676,9 @@ std::optional<ValueRefList> warp_affine_grad_rule(
ret[1] = imperative::apply(*grad_op, args_)[0];
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
items.push_back(std::make_tuple(1, true, true, false, false));
auto&& subtensor = Subtensor::make(items);
auto&& subtensor = Subtensor::make(items, slice_items);
CompNodeValue::ref_t device = inputs[0].device();
DTypeValue::ref_t dtype = inputs[0].dtype();
......
......@@ -781,14 +781,8 @@ std::pair<size_t, bool> get_ndim_safe(py::handle tensor) {
}
}
py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
py::tuple _unpack_indexes(py::handle inp_hdl, py::tuple tuple_val) {
py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
py::tuple tuple_val;
if (py::isinstance<py::tuple>(idx_hdl)) {
tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
} else {
tuple_val = py::make_tuple(idx_hdl);
}
bool use_subtensor = true;
bool need_remove_ellipsis = false;
......@@ -939,6 +933,20 @@ bool enable_fastpath(py::handle inp) {
return true;
}
// Decides whether an index tuple can be handled by the fast subtensor path.
// Every entry must be a scalar, a slice, Ellipsis or None and must not be a
// VarNode instance; in addition the indexed object has to pass
// enable_fastpath().
bool subtensor_fastpath(py::handle inp_hdl, py::tuple tuple_val) {
    const size_t index_count = tuple_val.size();
    for (size_t pos = 0; pos < index_count; ++pos) {
        PyObject* item = tuple_val[pos].ptr();
        const bool is_plain_index = is_scalar(item) || PySlice_Check(item) ||
                                    item == Py_Ellipsis || item == Py_None;
        // a single unsupported entry disqualifies the whole index
        if (!is_plain_index || PyObject_TypeCheck(item, py_varnode_type)) {
            return false;
        }
    }
    return enable_fastpath(inp_hdl);
}
py::object _broadcast_cpp(py::handle input, py::handle args) {
py::object shape = _expand_args(args);
py::list dims;
......@@ -1128,16 +1136,129 @@ py::object _adaptive_pool2d_cpp(
return ret[0];
}
// Fast path of __getitem__ for index tuples made only of Python scalars,
// slices, bools, None and Ellipsis.
// The indices are encoded as static attributes of a Subtensor op (cpp_items
// holds the axis and the begin/end/step/idx flags, slice_items holds the
// matching integer values), so the op is applied with the tensor as its only
// input. None/bool entries are collected separately and applied through an
// AddAxis op before the Subtensor op.
// Throws py::index_error for more than one Ellipsis, or for an Ellipsis when
// the ndim attribute cannot be read; throws py::value_error for an index type
// that is not handled here.
py::object _fastpath_getitem_cpp(py::handle inp_hdl, py::tuple tuple_val) {
py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
// axis that the next index entry applies to
int ax = 0;
bool use_ellipsis = false;
// number of entries that are neither None, Ellipsis nor a bool
size_t special_dim = 0;
for (size_t i = 0; i < tuple_val.size(); ++i) {
PyObject* obj = tuple_val[i].ptr();
if (obj == Py_Ellipsis) {
use_ellipsis = true;
// reject a second Ellipsis anywhere after this one
for (size_t j = i + 1; j < tuple_val.size(); j++) {
PyObject* obj_last = tuple_val[j].ptr();
if (obj_last == Py_Ellipsis) {
throw py::index_error("only one ellipsis is allowed.");
}
}
}
if (obj != Py_None && obj != Py_Ellipsis && obj != Py_True && obj != Py_False) {
special_dim++;
}
}
// ndim is only needed to expand an Ellipsis; a failed lookup is ignored (ndim
// stays 0) unless an Ellipsis is present.
size_t ndim = 0;
try {
ndim = getattr(inp_hdl, "ndim").cast<size_t>();
} catch (py::error_already_set& err) {
if (use_ellipsis) {
throw py::index_error(
"does not support Ellipsis when tensor's ndim is unknown.");
};
}
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
// axes at which AddAxis has to insert a new dimension
std::vector<int32_t> expand_items;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::object t = tuple_val[i];
if (t.ptr() == Py_Ellipsis) {
// skip the axes that are not addressed by an explicit entry
ax += ndim - special_dim;
} else if (PySlice_Check(t.ptr())) {
PySliceObject* s = (PySliceObject*)t.ptr();
// items: start/stop/step values; idx_items: whether each one was given
std::vector<int> items;
std::vector<bool> idx_items;
// Records one slice field: None stores the given default value with a false
// flag, anything else is cast to int and flagged as present.
auto push = [&](PyObject* v, int default_value) {
if (v == Py_None) {
items.push_back(default_value);
idx_items.push_back(false);
} else {
auto obj = py::reinterpret_borrow<py::object>(v);
items.push_back(obj.cast<int>());
idx_items.push_back(true);
}
};
// INT_MIN / INT_MAX are stored as placeholders for omitted fields
push(s->start, INT_MIN);
push(s->stop, INT_MAX);
push(s->step, INT_MAX);
// a slice with no field given (`:`) emits no item; it only advances the axis
if (idx_items[0] || idx_items[1] || idx_items[2]) {
cpp_items.push_back(
{ax, idx_items[0], idx_items[1], idx_items[2], false});
slice_items.push_back({items[0], items[1], items[2], INT_MAX});
}
ax += 1;
} else if (PyLong_Check(t.ptr()) && !PyBool_Check(t.ptr())) {
// plain integer index: only the idx flag is set, value goes in the 4th slot
cpp_items.push_back({ax, false, false, false, true});
slice_items.push_back({INT_MIN, INT_MAX, INT_MAX, t.cast<int>()});
ax += 1;
} else if (PyBool_Check(t.ptr())) {
// bool: records a new axis at the current position without advancing ax
expand_items.push_back(ax);
} else if (t.ptr() == Py_None) {
// None: records a new axis at the current position and advances ax
expand_items.push_back(ax);
ax += 1;
} else if (is_scalar(t.ptr())) {
// any other scalar is treated like an integer index
cpp_items.push_back({ax, false, false, false, true});
slice_items.push_back({INT_MIN, INT_MAX, INT_MAX, t.cast<int>()});
ax += 1;
} else {
throw py::value_error("fast path subtensor index not impl");
}
}
// insert the requested new axes first; the Subtensor op below is applied to
// the result of AddAxis
if (expand_items.size()) {
std::shared_ptr<OpDef> op = AddAxis::make(expand_items);
py::object Op = py::cast(op);
PyObject* p[2] = {Op.ptr(), inp.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
inp = ret[0];
}
// apply Subtensor with the op object and the tensor as the only arguments
std::shared_ptr<OpDef> op;
op = Subtensor::make(cpp_items, slice_items);
std::vector<PyObject*> p;
p.resize(2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
if (try_res.size() == 2) {
return try_res[0];
}
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
py::tuple tuple_val;
if (py::isinstance<py::tuple>(idx_hdl)) {
tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
} else {
tuple_val = py::make_tuple(idx_hdl);
}
if (subtensor_fastpath(inp_hdl, tuple_val)) {
return _fastpath_getitem_cpp(inp_hdl, tuple_val);
}
py::tuple up = _unpack_indexes(inp_hdl, tuple_val);
py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
for (size_t i = 0; i < py_items.size(); ++i) {
py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
cpp_items.push_back(
......@@ -1146,7 +1267,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
}
std::shared_ptr<OpDef> op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
op = Subtensor::make(cpp_items, slice_items);
} else {
op = IndexingMultiAxisVec::make(cpp_items);
}
......@@ -1170,11 +1291,18 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"));
}
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
py::tuple tuple_val;
if (py::isinstance<py::tuple>(idx_hdl)) {
tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
} else {
tuple_val = py::make_tuple(idx_hdl);
}
py::tuple up = _unpack_indexes(inp_hdl, tuple_val);
py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
for (size_t i = 0; i < py_items.size(); ++i) {
py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
cpp_items.push_back(
......@@ -1183,7 +1311,7 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
}
std::shared_ptr<OpDef> op, set_op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
op = Subtensor::make(cpp_items, slice_items);
} else {
op = IndexingMultiAxisVec::make(cpp_items);
}
......
......@@ -208,6 +208,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
cg::copy_tensor_value_to_shape(
tshp, tshp_nd->get_value().proxy_to_default_cpu());
}
if (tshp.is_empty()) {
return {Tensor::make(TensorLayout(tshp, src->dtype()), src->comp_node())};
}
TensorLayout tlayout = slayout.broadcast(tshp);
// memory forward
return {Tensor::make(src->blob(), src->offset(), tlayout)};
......
......@@ -103,8 +103,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
if (!src.ndim) {
return {{{{{}, src.dtype}, comp_node}}, false};
}
mgb_assert(src.is_contiguous(), "src should be contiguous");
return {{{src, comp_node}}, true};
return {{{{src, src.dtype}, comp_node}}, true};
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......@@ -138,10 +137,18 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
dnn_op.exec_with_ws(out, index, sub);
return {out};
}
// Builds one layout checker slot per input; the first three slots each get a
// callback that accepts only contiguous layouts.
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> checkers(inputs.size());
    checkers[0] = [](const TensorLayout& layout) { return layout.is_contiguous(); };
    // the remaining two inputs share the same requirement
    checkers[1] = checkers[0];
    checkers[2] = checkers[0];
    return checkers;
}
// Op trait registration for IndexingSetOneHot: hooks up the functions defined
// in this namespace, including the input layout constraint callback.
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.get_input_layout_constraint(get_input_layout_constraint)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace indexing_set_one_hot
......
......@@ -542,7 +542,6 @@ auto get_index(
OP_TRAIT_REG(NAME, NAME).apply_on_var_node(apply_on_var_node).fallback(); \
}
FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
......
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/internal/indexing_helper.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/tensor.h"
#include "../algo_chooser.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
using namespace mgb::opr::indexing;
namespace mgb::imperative {
namespace {
namespace subtensor {
auto get_index(
const VarNodeArray& inputs,
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask,
const std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>>& slice) {
size_t length = mask.size();
auto graph = inputs[0]->owner_graph();
auto comp_node = inputs[0]->comp_node();
opr::Subtensor::IndexDesc ret(length);
auto immutable_node = [&](int val) {
DTypeScalar scalar = DTypeScalar(static_cast<megdnn::dt_int32>(val));
return opr::ImmutableTensor::make(*graph, scalar, {comp_node});
};
for (size_t i = 0; i < length; ++i) {
auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = mask[i];
auto&& [b_val, e_val, s_val, ax_val] = slice[i];
ret[i].axis = axis;
if (idx_flag) {
ret[i].idx = immutable_node(ax_val);
} else {
if (b_flag) {
ret[i].begin = immutable_node(b_val);
}
if (e_flag) {
ret[i].end = immutable_node(e_val);
}
if (s_flag) {
ret[i].step = immutable_node(s_val);
}
}
}
return ret;
}
// Builds a Subtensor index descriptor whose begin/end/step/idx entries come
// from var-node inputs, consumed in order starting at position `vidx`.
// Every flag set in a mask entry consumes one input; an entry without the idx
// flag must set at least one of begin/end/step, and all inputs from `vidx`
// onwards have to be consumed.
auto origin_get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    opr::Subtensor::IndexDesc ret(mask.size());
    size_t pos = 0;
    for (const auto& [axis, begin, end, step, idx] : mask) {
        auto& desc = ret[pos++];
        desc.axis = axis;
        if (idx) {
            desc.idx = inputs[vidx++];
            continue;
        }
        mgb_assert(begin || end || step);
        if (begin) {
            desc.begin = inputs[vidx++];
        }
        if (end) {
            desc.end = inputs[vidx++];
        }
        if (step) {
            desc.step = inputs[vidx++];
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
// Computes the layout obtained by applying statically stored subtensor indices
// to `src`.
//
// items[i]       = (axis, has_begin, has_end, has_step, is_index)
// slice_items[i] = (begin, end, step, index); INT_MIN as begin and INT_MAX as
//                  end/step stand for "not given".
//
// Items are processed from last to first, so removing an indexed axis does not
// shift the axes of the items still to be processed (this relies on the items
// being ordered by axis). For a slice the axis extent and stride are
// rewritten; for an integer index the axis is removed, except that a 1-dim
// layout keeps a single axis of extent 1.
// Asserts with "index out of bound" when an axis, an integer index or a slice
// range does not fit the layout.
TensorLayout deduce_layout(
        TensorLayout src,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items,
        const std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>>& slice_items) {
    mgb_assert(
            slice_items.size() >= items.size(),
            "subtensor: %zu slice_items given for %zu items", slice_items.size(),
            items.size());
    // maps a negative position to its non-negative equivalent on an axis of
    // extent size_ax; an empty axis always yields 0
    auto mod_size = [](int v, int size_ax) -> int {
        if (size_ax == 0)
            return 0;
        return v < 0 ? v + size_ax : v;
    };
#define CHECK(cond) \
    mgb_assert(cond, "index out of bound: layout=%s", src.to_string().c_str())
    for (int i = static_cast<int>(items.size()) - 1; i >= 0; i--) {
        const bool idx_flag = std::get<4>(items[i]);
        auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
        // normalize a negative axis before it is used to address shape/stride
        int axis = std::get<0>(items[i]);
        if (axis < 0) {
            axis += static_cast<int>(src.ndim);
        }
        CHECK(axis >= 0 && axis < static_cast<int>(src.ndim));
        const int shape_axis = static_cast<int>(src.shape[axis]);

        if (idx_flag) {
            // an integer index has to address an existing element
            CHECK(ax_val >= -shape_axis && ax_val < shape_axis);
            if (src.ndim == 1) {
                src.shape[0] = 1;
            } else {
                src.remove_axis_inplace(axis);
            }
            continue;
        }

        const int slice_step = s_val == INT_MAX ? 1 : s_val;
        int slice_start = 0;
        int slice_stop = 0;
        if (slice_step > 0) {
            slice_start = mod_size(b_val == INT_MIN ? 0 : b_val, shape_axis);
            slice_stop = mod_size(e_val == INT_MAX ? shape_axis : e_val, shape_axis);
            slice_stop = std::min(slice_stop, shape_axis);
            slice_start = std::min(slice_start, slice_stop);
            CHECK(slice_start >= 0 && slice_stop >= slice_start &&
                  slice_stop <= shape_axis);
        } else {
            // with a negative step an omitted begin means "start at the last
            // element"; the omitted-begin marker lives in b_val, not in s_val
            slice_start = b_val == INT_MIN ? shape_axis - 1 : b_val;
            slice_start = mod_size(slice_start, shape_axis);
            slice_stop = e_val == INT_MAX ? -1 : mod_size(e_val, shape_axis);
            slice_start = std::min(slice_start, std::max(shape_axis - 1, 0));
            slice_stop = std::min(slice_stop, slice_start);
            CHECK(slice_step < 0 && slice_start >= 0 && slice_stop <= slice_start &&
                  slice_start < shape_axis && slice_stop >= -1);
        }
        const int abs_step = std::abs(slice_step);
        // number of elements visited: ceil(|stop - start| / |step|)
        src.shape[axis] =
                (std::abs(slice_stop - slice_start) + abs_step - 1) / abs_step;
        src.stride[axis] *= slice_step;
    }
#undef CHECK
    return src;
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Subtensor&>(def);
OperatorNodeConfig config{op.make_name()};
if (inputs.size() > 1) {
return opr::Subtensor::make(
inputs[0], origin_get_index(inputs, 1, op.items), config);
} else {
return opr::Subtensor::make(
inputs[0], get_index(inputs, op.items, op.slice_items), config);
}
}
// Infers the output descriptor of Subtensor.
// Two or more inputs: delegated to proxy_graph_detail.
// Single input: when the input layout has ndim == 0, a layout carrying only the
// dtype is returned with the flag set to false; otherwise the output layout is
// deduced from the indices stored in the op and the flag is true.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    if (inputs.size() > 1) {
        return proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
    }
    const auto& src_desc = inputs[0];
    const auto& comp_node = src_desc.comp_node;
    if (!src_desc.layout.ndim) {
        return {{{TensorLayout{src_desc.layout.dtype}, comp_node, {}}}, false};
    }
    const auto& subtensor_op = static_cast<const Subtensor&>(def);
    TensorLayout out_layout = deduce_layout(
            src_desc.layout, subtensor_op.items, subtensor_op.slice_items);
    return {{{out_layout, comp_node, {}}}, true};
}
// Eager implementation of Subtensor.
// With more than one input the call is delegated to proxy_graph_detail.
// With a single input the indices stored in the op are used: the result layout
// comes from deduce_layout() and the output is created on the blob of the
// input at an adjusted offset ("memory forward"), so no data is copied.
// An empty result is returned as a fresh tensor with the deduced layout.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    if (inputs.size() > 1) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }
    auto&& op = static_cast<const Subtensor&>(def);
    auto&& src = inputs[0];
    auto&& layout = src->layout();
    const auto& items = op.items;
    const auto& slice_items = op.slice_items;
    TensorLayout res_layout = deduce_layout(layout, items, slice_items);
    if (res_layout.is_empty()) {
        return {Tensor::make(res_layout, src->comp_node())};
    }
    // Offset of the first selected element, counted in elements. It is kept
    // signed because the contribution of an axis is start * stride.
    std::ptrdiff_t offset_elems = 0;
    for (size_t i = 0; i < items.size(); ++i) {
        const bool idx_flag = std::get<4>(items[i]);
        auto&& [b_val, e_val, s_val, ax_val] = slice_items[i];
        int axis = std::get<0>(items[i]);
        if (axis < 0) {
            axis += static_cast<int>(layout.ndim);
        }
        const int shape_axis = static_cast<int>(layout.shape[axis]);
        int first = 0;
        if (idx_flag) {
            first = ax_val < 0 ? ax_val + shape_axis : ax_val;
        } else {
            // Resolve the slice start the same way deduce_layout() does: a
            // negative begin counts from the end of the axis, and an omitted
            // begin is element 0 for a positive step but the last element for
            // a negative one.
            const int step = s_val == INT_MAX ? 1 : s_val;
            if (b_val == INT_MIN) {
                first = step > 0 ? 0 : shape_axis - 1;
            } else {
                first = b_val < 0 ? b_val + shape_axis : b_val;
                if (step < 0) {
                    first = std::min(first, shape_axis - 1);
                }
            }
        }
        offset_elems += static_cast<std::ptrdiff_t>(first) * layout.stride[axis];
    }
    const std::ptrdiff_t offset_bytes =
            offset_elems * static_cast<std::ptrdiff_t>(layout.dtype.size());
    const size_t offset = static_cast<size_t>(
            static_cast<std::ptrdiff_t>(src->offset()) + offset_bytes);
    // memory forward
    return {Tensor::make(src->blob(), offset, res_layout)};
}
// Returns one default-constructed layout checker slot per input; no callback
// is assigned to any of them here.
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    const size_t input_count = inputs.size();
    return SmallVector<VarNode::LayoutConstraintCallback>(input_count);
}
// Op trait registration for Subtensor: hooks up the graph-mode, shape
// inference, eager and layout-constraint functions defined in this namespace.
OP_TRAIT_REG(Subtensor, Subtensor, opr::Subtensor)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace subtensor
} // namespace
} // namespace mgb::imperative
\ No newline at end of file
......@@ -306,6 +306,19 @@ inline bool is_reduce_ndim_idx_items(
return false;
}
// Reports whether a Subtensor with statically stored indices contains an
// integer index.
// Only the first item whose idx flag (5th field) is set is inspected: the
// result is true when the index value stored in the 4th field of the matching
// slice_items entry differs from INT_MAX. Returns false when no item has the
// idx flag set.
// slice_items must provide an entry for every inspected item.
inline bool is_subtensor_reduce_ndim(
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items,
        const std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>>& slice_items) {
    for (size_t i = 0; i < items.size(); ++i) {
        if (std::get<4>(items[i])) {
            return std::get<3>(slice_items[i]) != INT_MAX;
        }
    }
    return false;
}
inline auto convert_nchw2nhwc_idx_items(
const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& items) {
auto nhwc_items = items;
......@@ -326,6 +339,34 @@ ValueRefList subtensor_rule(
const FormatTransformation& t) {
mgb_assert(inputs.size() >= 1);
auto& src = inputs[0].cast(t.value_type());
bool is_reduce_ndim = false;
if (inputs.size() > 1) {
is_reduce_ndim = is_reduce_ndim_idx_items(
op.items, {&inputs[1], &inputs[inputs.size() - 1]});
} else {
is_reduce_ndim = is_subtensor_reduce_ndim(op.items, op.slice_items);
}
if (!is_reduce_ndim) {
// only support NHWC2NCHW convert, otherwise maintain src's format
if (!(auto_convert && src.format() == FT::NHWC)) {
return {t.wrap_output(
imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
}
auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
auto outputs = imperative::apply(
*T::make(std::move(nhwc_items), op.slice_items, op.scope()),
t.unwrap_inputs(inputs));
return t.wrap_outputs(outputs, FT::NHWC);
}
return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
}
template <typename T>
ValueRefList indexing_rule(
const T& op, Span<ValueRef>& inputs, const bool& auto_convert,
const FormatTransformation& t) {
mgb_assert(inputs.size() >= 1);
auto& src = inputs[0].cast(t.value_type());
bool is_reduce_ndim = is_reduce_ndim_idx_items(
op.items, {&inputs[1], &inputs[inputs.size() - 1]});
if (!is_reduce_ndim) {
......@@ -597,7 +638,7 @@ struct FormatRuleRegistry {
register_format_rule(reshape_rule);
register_format_rule(broadcast_rule);
register_format_rule(subtensor_rule<Subtensor>);
register_format_rule(subtensor_rule<IndexingMultiAxisVec>);
register_format_rule(indexing_rule<IndexingMultiAxisVec>);
register_format_rule(setsubtensor_rule<SetSubtensor>);
register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
register_format_rule(elemwise_rule);
......
8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py
cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td
f27cdbb7926e0be9f5dabb8651d2e4da generated/opdef.h.inl
96817f709ee92c8e1eb7cb4168f28565 generated/opdef.cpp.inl
672668fa3ed11c27781f0fa380e6c8aa generated/opdef.py.inl
47511e3e7fed8c64a1c4fea48d79b3d1 generated/opdef.cpy.inl
6811fde221f86d1ef8de425a3c83127b ../../src/core/include/megbrain/ir/ops.td
55123da1605ef6edd79e3a2ede8aefeb generated/opdef.h.inl
6f4beb6d12cdd9ec4c4e61b6d7d35144 generated/opdef.cpp.inl
185ba3c3a0fce480ee498cef058670b2 generated/opdef.py.inl
b7ed7a638b7586709bb23dd153fb58b1 generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
......@@ -6918,6 +6918,7 @@ size_t Subtensor_hash_impl(const OpDef& def_) {
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.items));
val = mgb::hash_pair_combine(val, mgb::hash(op_.slice_items));
return val;
}
bool Subtensor_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
......@@ -6926,6 +6927,7 @@ bool Subtensor_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.items != b_.items) return false;
if (a_.slice_items != b_.slice_items) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> Subtensor_props_impl(const OpDef& def_) {
......@@ -6933,6 +6935,7 @@ std::vector<std::pair<const char*, std::string>> Subtensor_props_impl(const OpDe
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("items", "{std::vector}");
props_.emplace_back("slice_items", "{std::vector}");
return props_;
}
std::string Subtensor_make_name_impl(const OpDef& def_) {
......
......@@ -20113,7 +20113,8 @@ PyOpDefBegin(Subtensor) // {
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"items", serialization<decltype(opdef.items)>::dump(opdef.items)}
{"items", serialization<decltype(opdef.items)>::dump(opdef.items)},
{"slice_items", serialization<decltype(opdef.slice_items)>::dump(opdef.slice_items)}
};
return py::cast(state).release().ptr();
}
......@@ -20130,6 +20131,13 @@ PyOpDefBegin(Subtensor) // {
opdef.items = serialization<decltype(opdef.items)>::load(iter->second);
}
}
{
auto&& iter = state.find("slice_items");
if (iter != state.end()) {
opdef.slice_items = serialization<decltype(opdef.slice_items)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
......@@ -20139,9 +20147,9 @@ PyOpDefBegin(Subtensor) // {
PyOpDefEnd(Subtensor)
int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"items", "scope", NULL};
PyObject *items = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &items, &scope))
static const char* kwlist[] = {"items", "slice_items", "scope", NULL};
PyObject *items = NULL, *slice_items = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO", const_cast<char**>(kwlist), &items, &slice_items, &scope))
return -1;
if (items) {
......@@ -20153,6 +20161,15 @@ int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
} CATCH_ALL(-1)
}
if (slice_items) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Subtensor)*>(self)->inst().slice_items =
py::cast<decltype(Subtensor::slice_items)>(py::handle(slice_items));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
......@@ -20165,6 +20182,7 @@ int PyOp(Subtensor)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
PyGetSetDef PyOp(Subtensor)::py_getsetters[] = {
{const_cast<char*>("items"), py_get_generic(Subtensor, items), py_set_generic(Subtensor, items), const_cast<char*>("items"), NULL},
{const_cast<char*>("slice_items"), py_get_generic(Subtensor, slice_items), py_set_generic(Subtensor, slice_items), const_cast<char*>("slice_items"), NULL},
{NULL} /* Sentinel */
};
......@@ -20185,7 +20203,7 @@ PyMethodDef PyOp(Subtensor)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Subtensor)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, items: list[tuple[int, bool, bool, bool, bool]] = ...) -> None\n"
"__init__(self, items: list[tuple[int, bool, bool, bool, bool]] = ..., slice_items: list[tuple[int, int, int, int]] = ...) -> None\n"
};
void _init_py_Subtensor(py::module m) {
......
......@@ -1795,8 +1795,9 @@ class Subtensor : public OpDefImplBase<Subtensor> {
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items;
Subtensor() = default;
Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> slice_items_, std::string scope_ = {}): items(items_), slice_items(slice_items_) { set_scope(scope_); }
};
class TQT : public OpDefImplBase<TQT> {
......
......@@ -1882,9 +1882,10 @@ SplitInst
py::class_<Subtensor, std::shared_ptr<Subtensor>, OpDef> SubtensorInst(m, "Subtensor");
SubtensorInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>>, std::string>(), py::arg("items"), py::arg("slice_items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &Subtensor::items);
.def_readwrite("items", &Subtensor::items)
.def_readwrite("slice_items", &Subtensor::slice_items);
py::class_<TQT, std::shared_ptr<TQT>, OpDef> TQTInst(m, "TQT");
......
......@@ -380,7 +380,15 @@ class FancyIndexingBase<string name>: MgbHashableOp<name> {
);
}
def Subtensor: FancyIndexingBase<"Subtensor">;
def Subtensor: MgbHashableOp<"Subtensor"> {
let extraArguments = (ins
MgbArrayAttr<MgbTupleAttr<
[MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]>>:$items,
MgbArrayAttr<MgbTupleAttr<[MgbI32Attr, MgbI32Attr, MgbI32Attr, MgbI32Attr]>>:$slice_items
);
}
// def Subtensor: FancyIndexingBase<"Subtensor">;
def SetSubtensor: FancyIndexingBase<"SetSubtensor">;
def IncrSubtensor: FancyIndexingBase<"IncrSubtensor">;
def IndexingMultiAxisVec: FancyIndexingBase<"IndexingMultiAxisVec">;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册