提交 af29fcb2 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mgb/plugin): add param json func for indexing oprs

GitOrigin-RevId: b5becbbc028b16e4e1dcf313a1acf1c8b5bb11d8
上级 62753c4d
......@@ -512,6 +512,61 @@ REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter)
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData)
REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward)
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
cg::OperatorNodeBase * opr) {
auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
auto pattern = json::Array::make();
for (size_t i = 0; i < param.pattern_len; i++)
pattern->add(json::NumberInt::make(param.pattern[i]));
return json::Object::make({
{"ndim", json::NumberInt::make(param.ndim)},
{"pattern", pattern},
});
}
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
cg::OperatorNodeBase * opr) {
auto param = opr->cast_final_safe<opr::AxisAddRemove>().param();
auto desc = json::Array::make();
for (size_t i = 0; i < param.nr_desc; i++) {
auto axisdesc = param.desc[i];
desc->add(
json::Object::make({
{"method", json::NumberInt::make(
static_cast<int32_t>(axisdesc.method))},
{"axisnum", json::NumberInt::make(axisdesc.axis.get_raw())},
}));
}
return json::Object::make({
{"nr_desc", json::NumberInt::make(param.nr_desc)},
{"desc", desc},
});
}
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>(
cg::OperatorNodeBase * opr) {
auto desc = json::Array::make();
auto indices = opr->cast_final_safe<opr::Subtensor>().index_desc();
for (auto &index : indices){
desc->add(
json::Object::make({
{"axis", json::NumberInt::make(index.axis.get_raw())},
{"begin", json::NumberInt::make(index.begin.node() != nullptr)},
{"end", json::NumberInt::make(index.end.node() != nullptr)},
{"step", json::NumberInt::make(index.step.node() != nullptr)},
{"idx", json::NumberInt::make(index.idx.node() != nullptr)},
}));
}
return desc;
}
#endif // MGB_ENABLE_JSON
} // namespace
......@@ -573,6 +628,9 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::GroupLocal>();
add_single_param_json<opr::LRN>();
add_single_param_json<opr::Concat>();
add_single_param_json<opr::Dimshuffle>();
add_single_param_json<opr::AxisAddRemove>();
add_single_param_json<opr::Subtensor>();
add_single_param_json<opr::Reduce>();
add_single_param_json<opr::LocalShareForward>();
add_single_param_json<opr::LocalShareBackwardData>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册