Commit 9a6ba334 authored by Megvii Engine Team

feat(imperative): speed up concat and stack

GitOrigin-RevId: 614e87171908f419f98eb4c150c0a10a13f85c60
Parent 489af281
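This commit replaces the Python-level lowering of `stack` (an `expand_dims` per input followed by `concat`) with a dedicated `builtin.Stack` op, and gives both ops a fast layout-deduction and copy path in the imperative runtime. A minimal sketch of the user-facing API that gets faster (assumes a working MegEngine install; shapes are illustrative):

import numpy as np
import megengine.functional as F
from megengine import Tensor

xs = [Tensor(np.random.randn(2, 3).astype("float32")) for _ in range(4)]
out_cat = F.concat(xs, axis=0)  # shape (8, 3)
out_stk = F.stack(xs, axis=0)   # shape (4, 2, 3)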
@@ -489,9 +489,9 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
return inps[0]
if device is None:
device = get_device(inps)
device = as_device(device)
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
device = get_default_device()
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inps)
return result
@@ -516,13 +516,12 @@ def stack(inps, axis=0, device=None):
array([[0., 1., 2.],
[6., 7., 8.]], dtype=float32)
"""
if len(inps) > 0 and not isinstance(inps[0].shape, inps[0].__class__):
shapes = {arr.shape for arr in inps}
if len(shapes) != 1:
raise ValueError("All input tensors must have the same shape")
inps = [expand_dims(inp, axis=axis) for inp in inps]
return concat(inps, axis=axis, device=device)
if len(inps) == 1:
return expand_dims(inps[0], axis=axis)
if device is None:
device = get_default_device()
(result,) = apply(builtin.Stack(axis=axis, comp_node=device), *inps)
return result
def split(inp, nsplits_or_sections, axis=0):
......
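The hunks above turn `stack` from a generic lowering into a single-op dispatch; roughly, before and after (taken from the diff; `builtin.Stack` is the op added by this commit):

# old path: one expand_dims per input, then a Concat
inps = [expand_dims(inp, axis=axis) for inp in inps]
return concat(inps, axis=axis, device=device)

# new path: a single Stack op consumes all inputs at once
(result,) = apply(builtin.Stack(axis=axis, comp_node=device), *inps)
return result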
@@ -5,21 +5,29 @@ cd $(dirname $0)/..
ISORT_ARG=""
BLACK_ARG=""
while getopts 'd' OPT; do
while getopts 'dt:' OPT; do
case $OPT in
d)
ISORT_ARG="--diff --check-only"
BLACK_ARG="--diff --check"
;;
t)
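# format only the directory passed via -t instead of the default set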
TARGET=$OPTARG
;;
?)
echo "Usage: `basename $0` [-d]"
esac
done
directories=(megengine test)
if [[ -d examples ]]; then
directories+=(examples)
if [[ $TARGET ]]; then
directories=($TARGET)
else
directories=(megengine test)
if [[ -d examples ]]; then
directories+=(examples)
fi
fi
# do not run isort on megengine/__init__.py, because we must
# initialize the library load path before loading dependent libs in core
isort $ISORT_ARG -j $(nproc) -rc "${directories[@]}" -s megengine/__init__.py
......
@@ -176,9 +176,12 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
return res;
}
// if none of the inputs is a megengine tensor, return get_default_device();
// otherwise, check that all input tensors are on the same device
CompNode _get_device(PyObject* const* args, size_t nargs) {
bool is_tuple = false;
PyObject* tuple = nullptr;
// convert input args to a tuple
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
......
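The comment above summarizes the device-resolution rule; sketched in Python for clarity (hypothetical helper name, the real logic is the C++ `_get_device` above and `get_device` later in this diff):

def resolve_device(args):
    devices = {t.device for t in args if isinstance(t, Tensor)}
    if not devices:
        # no megengine tensor among the inputs
        return get_default_device()
    assert len(devices) == 1, "input tensors must be on the same device"
    return devices.pop()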
@@ -217,6 +217,55 @@ def test_split_basic(is_varnode):
set_symbolic_shape(saved_symbolic_shape)
def test_concat_and_stack():
import copy
def generate_test_data(max_nr_inp, max_dim, max_dim_len, test_concat=True):
nr_inp = np.random.randint(1, max_nr_inp)
dims = np.random.randint(1, max_dim)
cat_axis = (
np.random.randint(-dims, dims)
if test_concat
else np.random.randint(-dims - 1, dims + 1)
)
ishape = [np.random.randint(0, max_dim_len) for _ in range(dims)]
ishapes = [copy.deepcopy(ishape) for _ in range(nr_inp)]
if test_concat:
for i in range(nr_inp):
ishapes[i][cat_axis] = np.random.randint(0, max_dim_len)
inp_nps = []
for ishape in ishapes:
inp_nps.append(np.random.randn(*ishape))
return inp_nps, cat_axis
def test_impl(max_nr_inp, max_dim, max_dim_len, test_concat):
inp_nps, cat_axis = generate_test_data(
max_nr_inp, max_dim, max_dim_len, test_concat
)
inp_mges = [Tensor(inp_np) for inp_np in inp_nps]
if test_concat:
np_func, mge_func = np.concatenate, F.concat
else:
np_func, mge_func = np.stack, F.stack
res_np = np_func(inp_nps, axis=cat_axis)
res_mge = mge_func(inp_mges, axis=cat_axis)
np.testing.assert_allclose(res_mge.numpy(), res_np)
def test_concat(max_nr_inp, max_dim, max_dim_len):
test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=True)
def test_stack(max_nr_inp, max_dim, max_dim_len):
test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=False)
for _ in range(3):
test_concat(10, 7, 16)
for _ in range(3):
test_stack(10, 7, 16)
@pytest.mark.parametrize("symbolic", [None, False, True])
def test_split(symbolic):
x = Tensor(np.random.random((10, 20)), dtype=np.float32)
......
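Note that `np.random.randint(0, max_dim_len)` can yield zero-length dimensions, so the test also covers empty inputs; the physical implementations skip the copy for empty layouts. The expected semantics, in plain numpy:

a = np.random.randn(0, 3)  # empty along axis 0
b = np.random.randn(2, 3)
assert np.concatenate([a, b], axis=0).shape == (2, 3)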
@@ -28,7 +28,7 @@ public:
// FIXME: maybe in-place style deduction works better
template <typename... TArgs>
TensorLayout deduce_layout(TArgs&&... args) {
static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
// static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
TensorLayout output_layout;
m_opr->deduce_layout(args..., output_layout);
return output_layout;
@@ -36,7 +36,7 @@ public:
template <typename... TArgs>
TensorLayout deduce_layout_fallible(TArgs&&... args) {
static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
// static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
TensorLayout output_layout;
bool success = (args.ndim * ...) > 0;
if (success) {
@@ -49,7 +49,7 @@ public:
template <size_t nr_outputs, typename... TArgs>
std::array<TensorLayout, nr_outputs> deduce_layouts(TArgs&&... args) {
static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
// static_assert((std::is_convertible_v<TArgs, TensorLayout> && ...));
std::array<TensorLayout, nr_outputs> layouts;
std::apply(
[&](auto&&... outputs) { m_opr->deduce_layout(args..., outputs...); },
......
#include <climits>
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/utils/stats.h"
namespace mgb::imperative {
namespace {
template <typename Opr>
CompNode get_device(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<Opr>();
const char* op_name = op_def.make_name().c_str();
CompNode oup_cn = op_def.comp_node;
if (!oup_cn.valid()) {
size_t nr_inp = inputs.size();
mgb_assert(
nr_inp > 0, "number of inputs of %s should be greater than 0", op_name);
auto&& inp_cn = inputs[0].comp_node;
for (size_t i = 1; i < nr_inp; ++i) {
mgb_assert(
inp_cn == inputs[i].comp_node,
"input tensors of %s operator should have same device, but get "
"%s vs %s",
op_name, inp_cn.to_string().c_str(),
inputs[i].comp_node.to_string().c_str());
}
oup_cn = inp_cn;
}
return oup_cn;
}
bool is_all_inputs_valid(const SmallVector<LogicalTensorDesc>& inputs) {
bool input_valid = true;
size_t nr_inp = inputs.size();
for (size_t i = 0; i < nr_inp; ++i) {
if (inputs[i].layout.ndim == 0) {
input_valid = false;
break;
}
}
return input_valid;
}
} // namespace
namespace concatenate {
TensorLayout concat_layout_deduce(
const SmallVector<const TensorLayout*> inputs, int axis) {
// if we used megdnn::Concat::deduce_layout directly, we would need to construct a
// TensorLayoutArray, which would incur a lot of memory copying
auto shape_equal_but_specific_axis = [](const TensorShape& lhs,
const TensorShape& rhs, int axis) -> bool {
if (lhs.ndim != rhs.ndim) {
return false;
}
for (size_t i = 0; i < lhs.ndim; ++i) {
if (i == axis)
continue;
if (lhs.shape[i] != rhs.shape[i])
return false;
}
return true;
};
TensorLayout oup_layout = *inputs[0];
for (size_t i = 1; i < inputs.size(); ++i) {
mgb_assert(
shape_equal_but_specific_axis(oup_layout, *inputs[i], axis),
"Concat input shape mismatch: %s vs %s", inputs[0]->to_string().c_str(),
inputs[i]->to_string().c_str());
oup_layout.shape[axis] += inputs[i]->shape[axis];
}
oup_layout.init_contiguous_stride();
return oup_layout;
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Concat&>(def);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Concat::make(inputs, op.axis, config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<Concat>();
auto oup_cn = get_device<Concat>(def, inputs);
if (!is_all_inputs_valid(inputs)) {
// dtype promotion (dtypepromote_trans) has already unified dtypes, so use inputs[0].dtype as oup_dtype here
return {{{TensorLayout{inputs[0].layout.dtype}, oup_cn, {}}}, false};
}
SmallVector<const TensorLayout*> inputs_holder(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
inputs_holder[i] = &inputs[i].layout;
}
int axis = op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0].layout.ndim;
TensorLayout oup_layout = concat_layout_deduce(inputs_holder, axis);
return {{{oup_layout, oup_cn, {}}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<Concat>();
int axis = op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim;
CompNode& oup_cn = output_descs[0].comp_node;
if (op_def.comp_node.valid()) {
mgb_assert(op_def.comp_node == oup_cn, "Concat compnode infer error");
}
// prepare inputs and output layout
TensorLayout& oup_layout = output_descs[0].layout;
if (!validated) {
SmallVector<const TensorLayout*> inputs_holder(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
inputs_holder[i] = &inputs[i]->layout();
}
oup_layout = concat_layout_deduce(inputs_holder, axis);
}
auto oup = Tensor::make(oup_layout, oup_cn);
// the dnn concat kernel is very slow, so we copy each input into a slice of the
// output instead, following the implementation in src/opr/impl/tensor_manip.cpp
auto&& out = oup->dev_tensor();
size_t end = 0;
for (auto&& input : inputs) {
auto&& in = input->dev_tensor();
auto begin = end;
end = begin + in.shape().shape[axis];
if (!in.layout().is_empty()) {
out.sub(Slice(begin, end).apply(out.layout(), axis))
.copy_from_fixlayout(in);
}
}
return {oup};
}
OP_TRAIT_REG(Concat, Concat)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace concatenate
namespace stack {
TensorLayout stack_layout_deduce(
const SmallVector<const TensorLayout*> inputs, int axis) {
size_t nr_inp = inputs.size();
auto&& inp_layout0 = *inputs[0];
for (size_t i = 1; i < nr_inp; ++i) {
mgb_assert(
inp_layout0.eq_shape(*inputs[i]),
"Stack input shape mismatch: %s vs %s", inp_layout0.to_string().c_str(),
inputs[i]->to_string().c_str());
}
TensorLayout oup_layout{TensorShape{inp_layout0}, inp_layout0.dtype};
oup_layout.add_axis_cont_inplace(axis);
oup_layout.shape[axis] = nr_inp;
oup_layout.init_contiguous_stride();
return oup_layout;
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Stack&>(def);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
using Desc = opr::AxisAddRemove::AxisDesc;
std::vector<Desc> param{Desc::make_add(op.axis)};
VarNodeArray expanded_inputs;
for (auto&& inp : inputs) {
expanded_inputs.emplace_back(
opr::AxisAddRemove::make(inp, param, cg::OperatorNodeConfig{}).node());
}
return opr::Concat::make(expanded_inputs, op.axis, config);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<Stack>();
auto oup_cn = get_device<Stack>(def, inputs);
if (!is_all_inputs_valid(inputs)) {
// dtype promotion (dtypepromote_trans) has already unified dtypes, so use inputs[0].dtype as oup_dtype here
return {{{TensorLayout{inputs[0].layout.dtype}, oup_cn, {}}}, false};
}
SmallVector<const TensorLayout*> inputs_holder(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
inputs_holder[i] = &inputs[i].layout;
}
int axis = op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0].layout.ndim + 1;
TensorLayout oup_layout = stack_layout_deduce(inputs_holder, axis);
return {{{oup_layout, oup_cn, {}}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<Stack>();
size_t nr_inp = inputs.size();
TensorLayout inp_layout = inputs[0]->layout();
int axis =
op_def.axis >= 0 ? op_def.axis : op_def.axis + inputs[0]->layout().ndim + 1;
CompNode& oup_cn = output_descs[0].comp_node;
if (op_def.comp_node.valid()) {
mgb_assert(op_def.comp_node == oup_cn, "Stack compnode infer error");
}
// prepare inputs and output layout
TensorLayout& oup_layout = output_descs[0].layout;
if (!validated) {
SmallVector<const TensorLayout*> inputs_holder(inputs.size());
for (size_t i = 0; i < nr_inp; ++i) {
inputs_holder[i] = &inputs[i]->layout();
}
oup_layout = stack_layout_deduce(inputs_holder, axis);
}
inp_layout.add_axis_cont_inplace(axis);
SmallVector<TensorPtr> expanded;
for (size_t i = 0; i < nr_inp; ++i) {
expanded.push_back(
Tensor::make(inputs[i]->blob(), inputs[i]->offset(), inp_layout));
}
auto oup = Tensor::make(oup_layout, oup_cn);
// the dnn concat kernel is very slow, so we copy each input into a slice of the
// output instead, following the implementation in src/opr/impl/tensor_manip.cpp
auto&& out = oup->dev_tensor();
size_t end = 0;
for (auto&& input : expanded) {
auto&& in = input->dev_tensor();
auto begin = end;
end = begin + in.shape().shape[axis];
if (!in.layout().is_empty()) {
out.sub(Slice(begin, end).apply(out.layout(), axis))
.copy_from_fixlayout(in);
}
}
return {oup};
}
OP_TRAIT_REG(Stack, Stack)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace stack
} // namespace mgb::imperative
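Both `apply_on_physical_tensor` implementations above sidestep the slow dnn concat kernel: they deduce the output layout, allocate it once, and copy every input into the matching slice; for `Stack`, each input is first re-viewed with an extra unit axis, which is metadata-only. A numpy sketch of the same strategy (illustrative only, not MegEngine API):

import numpy as np

def concat_by_slices(inps, axis):
    out_shape = list(inps[0].shape)
    out_shape[axis] = sum(x.shape[axis] for x in inps)  # concat_layout_deduce
    out = np.empty(out_shape, dtype=inps[0].dtype)
    end = 0
    for x in inps:
        begin, end = end, end + x.shape[axis]
        sl = [slice(None)] * out.ndim
        sl[axis] = slice(begin, end)
        if x.size:  # empty inputs are skipped, as in the C++ above
            out[tuple(sl)] = x
    return out

def stack_by_slices(inps, axis):
    # np.expand_dims returns a view, mirroring add_axis_cont_inplace
    return concat_by_slices([np.expand_dims(x, axis) for x in inps], axis)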
@@ -384,18 +384,6 @@ OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback();
} // namespace typecvt
} // namespace
namespace {
namespace concat {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Concat&>(def);
cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Concat::make(inputs, op.axis, config);
}
OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback();
} // namespace concat
} // namespace
namespace {
namespace copy {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
......
@@ -53,7 +53,7 @@ SmallVector<LayoutConstraintCallback> get_input_layout_constraint(
VarNodeArray vinputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
OperatorNodeConfig config;
auto&& layout = inputs[i]->layout();
auto layout = inputs[i]->layout();
layout.init_contiguous_stride();
vinputs[i] = graph->insert_opr(std::make_unique<mgb::opr::SharedDeviceTensor>(
*graph,
......
@@ -391,6 +391,7 @@ struct DTypePromoteRuleRegistry {
register_dtype_promote_rule<Elemwise>(elemwise_rule);
register_dtype_promote_rule<ElemwiseMultiType>(elemwise_multi_type_rule);
register_dtype_promote_rule<Concat>(naive_promote_rule);
register_dtype_promote_rule<Stack>(naive_promote_rule);
register_dtype_promote_rule<GroupLocal>(naive_promote_rule);
register_dtype_promote_rule<Reduce>(reduce_rule);
register_dtype_promote_rule<Convolution>(convolution_rule);
......
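With `naive_promote_rule` now registered for `Stack` (as it already was for `Concat`), mixed-dtype inputs are promoted to a common dtype before the op sees them. A sketch of the expected behavior, hedged since the exact rule lives in the promotion pass:

x = Tensor(np.arange(3, dtype="int32"))
y = Tensor(np.arange(3, dtype="float32"))
F.stack([x, y]).dtype  # expected: float32, via naive promotion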
@@ -133,7 +133,7 @@ public:
DType dtype() const { return m_dtype; }
TensorLayout layout() const { return m_layout; }
const TensorLayout& layout() const { return m_layout; }
const TensorShape& shape() const { return m_shape; }
......
8dd504f360fd3d3bfb560c970b568153 ../../dnn/scripts/opr_param_defs.py
06e8a3af239b545470b38b3e82960935 ../../src/core/include/megbrain/ir/ops.td
7f37497cffb24554073cbc42b89e2db8 generated/opdef.h.inl
1e2041f6374e48d53762ddfe7a6ebca3 generated/opdef.cpp.inl
9a813355a742330e9ba6e5c14ea67c7c generated/opdef.py.inl
8d4ae7fef8234d8c79ac52017f4710e3 generated/opdef.cpy.inl
7d6df1c8e50a22ef2c36b7ea89daa9c5 ../../src/core/include/megbrain/ir/ops.td
f30ae9494b4bf3363cd74d9396acaf49 generated/opdef.h.inl
cb27f486b28a099221f38c6fcaa06a44 generated/opdef.cpp.inl
adb758acd1147f213db7f0cb1b708773 generated/opdef.py.inl
30ad8e75a5994edf9ec46387c6285312 generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
@@ -6993,6 +6993,46 @@ OP_TRAIT_REG(Split, Split)
.props(Split_props_impl)
.make_name(Split_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Stack);
namespace {
size_t Stack_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Stack>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.axis));
val = mgb::hash_pair_combine(val, mgb::hash(op_.comp_node));
return val;
}
bool Stack_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<Stack>(),
&&b_ = rhs_.cast_final_safe<Stack>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.axis != b_.axis) return false;
if (a_.comp_node != b_.comp_node) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> Stack_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Stack>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("axis", std::to_string(op_.axis));
props_.emplace_back("comp_node", op_.comp_node.to_string());
return props_;
}
std::string Stack_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<Stack>();
static_cast<void>(op_);
return "Stack";
}
} // anonymous namespace
OP_TRAIT_REG(Stack, Stack)
.hash(Stack_hash_impl)
.is_same_st(Stack_is_same_st_impl)
.props(Stack_props_impl)
.make_name(Stack_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Subtensor);
namespace {
......
@@ -20376,6 +20376,133 @@ void _init_py_Split(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Split::typeinfo(), &py_type).second);
}
PyOpDefBegin(Stack) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(Stack)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"axis", serialization<decltype(opdef.axis)>::dump(opdef.axis)},
{"comp_node", serialization<decltype(opdef.comp_node)>::dump(opdef.comp_node)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(Stack)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("axis");
if (iter != state.end()) {
opdef.axis = serialization<decltype(opdef.axis)>::load(iter->second);
}
}
{
auto&& iter = state.find("comp_node");
if (iter != state.end()) {
opdef.comp_node = serialization<decltype(opdef.comp_node)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd(Stack)
int PyOp(Stack)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"axis", "comp_node", "scope", NULL};
PyObject *axis = NULL, *comp_node = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO", const_cast<char**>(kwlist), &axis, &comp_node, &scope))
return -1;
if (axis) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Stack)*>(self)->inst().axis =
py::cast<decltype(Stack::axis)>(py::handle(axis));
} CATCH_ALL(-1)
}
if (comp_node) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(Stack)*>(self)->inst().comp_node =
py::cast<decltype(Stack::comp_node)>(py::handle(comp_node));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(Stack)::py_getsetters[] = {
{const_cast<char*>("axis"), py_get_generic(Stack, axis), py_set_generic(Stack, axis), const_cast<char*>("axis"), NULL},
{const_cast<char*>("comp_node"), py_get_generic(Stack, comp_node), py_set_generic(Stack, comp_node), const_cast<char*>("comp_node"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(Stack)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(Stack)::getstate, METH_NOARGS, "Stack getstate"},
{const_cast<char*>("__setstate__"), PyOp(Stack)::setstate, METH_VARARGS, "Stack setstate"},
{NULL} /* Sentinel */
};
PyObject *PyOp(Stack)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp(Stack)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
PyMethodDef PyOp(Stack)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp(Stack)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"__init__(self, axis: int = ..., comp_node: str = ...) -> None\n"
};
void _init_py_Stack(py::module m) {
using py_op = PyOp(Stack);
auto& py_type = PyOpType(Stack);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.Stack";
py_type.tp_basicsize = sizeof(PyOp(Stack));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "Stack";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType(Stack), &PyOp(Stack)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
PyType_Modified(&py_type);
m.add_object("Stack", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Stack::typeinfo(), &py_type).second);
}
PyOpDefBegin(Subtensor) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
@@ -22064,6 +22191,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_SlidingWindowTranspose(m); \
_init_py_Softmax(m); \
_init_py_Split(m); \
_init_py_Stack(m); \
_init_py_Subtensor(m); \
_init_py_TQT(m); \
_init_py_TensorRTRuntime(m); \
......
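The generated CPython bindings expose the new op as `megengine.core._imperative_rt.ops.Stack`, with `axis`/`comp_node` attributes and `__getstate__`/`__setstate__` hooks. A small usage sketch (values illustrative; pickling relies on the getstate/setstate pair above):

import pickle
from megengine.core._imperative_rt.ops import Stack

op = Stack(axis=1)
assert op.axis == 1
op2 = pickle.loads(pickle.dumps(op))  # roundtrips through getstate/setstate
assert op2.axis == op.axis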
@@ -1819,6 +1819,20 @@ public:
}
};
class Stack : public OpDefImplBase<Stack> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = 0;
::mgb::CompNode comp_node;
Stack() = default;
Stack(int32_t axis_, ::mgb::CompNode comp_node_, std::string scope_ = {}): axis(axis_), comp_node(comp_node_) { set_scope(scope_); }
Stack(::megdnn::param::Axis packed_param_0, ::mgb::CompNode comp_node_): axis(packed_param_0.axis), comp_node(comp_node_) {}
::megdnn::param::Axis param() const {
return {axis};
}
};
class Subtensor : public OpDefImplBase<Subtensor> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
@@ -1896,6 +1896,14 @@ SplitInst
.def_readwrite("axis", &Split::axis)
.def_readwrite("nsections", &Split::nsections);
py::class_<Stack, std::shared_ptr<Stack>, OpDef> StackInst(m, "Stack");
StackInst
.def(py::init<int32_t, ::mgb::CompNode, std::string>(), py::arg("axis") = 0, py::arg("comp_node"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &Stack::axis)
.def_readwrite("comp_node", &Stack::comp_node);
py::class_<Subtensor, std::shared_ptr<Subtensor>, OpDef> SubtensorInst(m, "Subtensor");
SubtensorInst
......
@@ -296,6 +296,12 @@ def Concat: MgbHashableOp<"Concat", [AxisParam]> {
);
}
def Stack: MgbHashableOp<"Stack", [AxisParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
);
}
def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$shape
......