提交 3116e9f7 编写于 作者: M Megvii Engine Team

feat(mgb): support get opr param json from serailzed param

GitOrigin-RevId: f9f0dcbde103892f59f5679d483e90467ad7b5a7
上级 634f92fe
......@@ -5,6 +5,7 @@
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/pooling.h"
......@@ -13,6 +14,7 @@
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/internal/indexing_helper.h"
#include "megbrain/opr/internal/indexing_helper_sereg.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
......@@ -20,6 +22,7 @@
#include "megbrain/opr/standalone/nms_opr.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/opr_load_dump.h"
#if MGB_ENABLE_JSON
#include "megdnn/opr_param_json.h"
#endif
......@@ -488,12 +491,24 @@ uint64_t opr_footprint_func<opr::Host2DeviceCopy>(cg::OperatorNodeBase* opr) {
template <class T>
std::shared_ptr<json::Value> opr_param_json_func(cg::OperatorNodeBase* opr);
template <class T>
std::shared_ptr<json::Value> serial_param_json_func(
serialization::OprLoadContextRawPOD& context);
#define REGISTE_SERIAL_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> serial_param_json_func<opr::cls>( \
serialization::OprLoadContextRawPOD & context) { \
return opr::opr_param_to_json(context.read_param<opr::cls::Param>()); \
}
#define REGISTE_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \
return opr::opr_param_to_json(opr->cast_final_safe<opr::cls>().param()); \
}
} \
REGISTE_SERIAL_PARAM_JSON_FUNC(cls)
REGISTE_PARAM_JSON_FUNC(Elemwise)
REGISTE_PARAM_JSON_FUNC(ConvolutionForward)
......@@ -544,12 +559,12 @@ REGISTE_PARAM_JSON_FUNC(GaussianRNG)
REGISTE_PARAM_JSON_FUNC(Linspace)
REGISTE_PARAM_JSON_FUNC(Eye)
REGISTE_PARAM_JSON_FUNC(CvtColor)
REGISTE_PARAM_JSON_FUNC(LayerNormBackward)
REGISTE_PARAM_JSON_FUNC(AdaptivePoolingBackward)
REGISTE_PARAM_JSON_FUNC(DropoutBackward)
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
cg::OperatorNodeBase* opr) {
auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
std::shared_ptr<json::Value> dimshuffle_param2json(
const opr::Dimshuffle::Param& param) {
auto pattern = json::Array::make();
for (size_t i = 0; i < param.pattern_len; i++)
pattern->add(json::NumberInt::make(param.pattern[i]));
......@@ -561,10 +576,19 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
}
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
cg::OperatorNodeBase* opr) {
auto param = opr->cast_final_safe<opr::AxisAddRemove>().param();
auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
return dimshuffle_param2json(param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::Dimshuffle>(
serialization::OprLoadContextRawPOD& context) {
return dimshuffle_param2json(context.read_param<opr::Dimshuffle::Param>());
}
std::shared_ptr<json::Value> axis_add_remove_param2json(
const opr::AxisAddRemove::Param& param) {
auto desc = json::Array::make();
for (size_t i = 0; i < param.nr_desc; i++) {
auto axisdesc = param.desc[i];
......@@ -581,6 +605,19 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
});
}
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
cg::OperatorNodeBase* opr) {
auto param = opr->cast_final_safe<opr::AxisAddRemove>().param();
return axis_add_remove_param2json(param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::AxisAddRemove>(
serialization::OprLoadContextRawPOD& context) {
return axis_add_remove_param2json(context.read_param<opr::AxisAddRemove::Param>());
}
std::shared_ptr<json::Value> indexing_param_to_json(
const std::vector<opr::indexing::AxisIndexer>& indices) {
auto desc = json::Array::make();
......@@ -596,12 +633,29 @@ std::shared_ptr<json::Value> indexing_param_to_json(
return desc;
}
#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \
auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \
return indexing_param_to_json(indices); \
#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \
auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \
return indexing_param_to_json(indices); \
} \
template <> \
std::shared_ptr<json::Value> serial_param_json_func<opr::cls>( \
serialization::OprLoadContextRawPOD & context) { \
auto indices = context.read_param<serialization::IndexDescMaskDump>(); \
auto desc = json::Array::make(); \
for (size_t i = 0; i < indices.nr_item; i++) { \
auto&& index = indices.items[i]; \
desc->add(json::Object::make({ \
{"axis", json::NumberInt::make(index.axis)}, \
{"begin", json::NumberInt::make(index.begin)}, \
{"end", json::NumberInt::make(index.end)}, \
{"step", json::NumberInt::make(index.step)}, \
{"idx", json::NumberInt::make(index.idx)}, \
})); \
} \
return desc; \
}
REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor);
......@@ -617,14 +671,11 @@ REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing);
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
cg::OperatorNodeBase* opr) {
std::shared_ptr<json::Value> reshape_param2json(const opr::Reshape::Param& param) {
auto desc = json::Array::make();
auto axis_param = opr->cast_final_safe<opr::Reshape>().param();
if (axis_param.axis != axis_param.MAX_NDIM) {
if (param.axis != param.MAX_NDIM) {
return json::Object::make({
{"axis", json::NumberInt::make(axis_param.axis)},
{"axis", json::NumberInt::make(param.axis)},
});
} else {
return json::Object::make();
......@@ -632,13 +683,24 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
}
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>(
std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
cg::OperatorNodeBase* opr) {
auto axis_param = opr->cast_final_safe<opr::Reshape>().param();
return reshape_param2json(axis_param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::Reshape>(
serialization::OprLoadContextRawPOD& context) {
return reshape_param2json(context.read_param<opr::Reshape::Param>());
}
std::shared_ptr<json::Value> getvarshape_param2json(
const opr::GetVarShape::Param& param) {
auto desc = json::Array::make();
auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param();
if (axis_param.axis != axis_param.MAX_NDIM) {
if (param.axis != param.MAX_NDIM) {
return json::Object::make({
{"axis", json::NumberInt::make(axis_param.axis)},
{"axis", json::NumberInt::make(param.axis)},
});
} else {
return json::Object::make();
......@@ -646,15 +708,39 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>(
}
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>(
std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>(
cg::OperatorNodeBase* opr) {
auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param();
auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param();
return getvarshape_param2json(axis_param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::GetVarShape>(
serialization::OprLoadContextRawPOD& context) {
return getvarshape_param2json(context.read_param<opr::GetVarShape::Param>());
}
std::shared_ptr<json::Value> nmskeep_param2json(
const opr::standalone::NMSKeep::Param& param) {
return json::Object::make({
{"iou_thresh", json::Number::make(nms_param.iou_thresh)},
{"max_output", json::Number::make(nms_param.max_output)},
{"iou_thresh", json::Number::make(param.iou_thresh)},
{"max_output", json::Number::make(param.max_output)},
});
}
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>(
cg::OperatorNodeBase* opr) {
auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param();
return nmskeep_param2json(nms_param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::standalone::NMSKeep>(
serialization::OprLoadContextRawPOD& context) {
return nmskeep_param2json(context.read_param<opr::standalone::NMSKeep::Param>());
}
#endif // MGB_ENABLE_JSON
} // namespace
......@@ -675,6 +761,9 @@ void OprFootprint::add_single_param_json() {
auto&& record = m_type2param_json.emplace(
OprType::typeinfo(), opr_param_json_func<OprType>);
mgb_assert(record.second, "duplicate opr typeinfo");
auto&& record1 = m_type2serialparam_json.emplace(
OprType::typeinfo(), serial_param_json_func<OprType>);
mgb_assert(record1.second, "duplicate opr typeinfo");
}
#endif
......@@ -767,6 +856,9 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::Eye>();
add_single_param_json<opr::standalone::NMSKeep>();
add_single_param_json<opr::CvtColor>();
add_single_param_json<opr::LayerNormBackward>();
add_single_param_json<opr::AdaptivePoolingBackward>();
add_single_param_json<opr::DropoutBackward>();
#endif
}
......@@ -814,6 +906,15 @@ std::shared_ptr<json::Value> OprFootprint::get_param_json(cg::OperatorNodeBase*
return json::Object::make();
}
std::shared_ptr<json::Value> OprFootprint::get_serial_param_json(
Typeinfo* type, serialization::OprLoadContextRawPOD& context) {
auto param_trait = m_type2serialparam_json.find(type);
if (param_trait != m_type2serialparam_json.end()) {
return (param_trait->second)(context);
}
return json::Object::make();
}
std::shared_ptr<json::Value> OprFootprint::Result::to_json() const {
using namespace json;
std::shared_ptr<Value> comp;
......
#pragma once
#include "megbrain/graph.h"
#include "megbrain/serialization/opr_load_dump.h"
namespace mgb {
......@@ -14,7 +15,10 @@ class OprFootprint {
#if MGB_ENABLE_JSON
using ParamJsonTrait =
thin_function<std::shared_ptr<json::Value>(cg::OperatorNodeBase*)>;
using SerialParamJsonTrait = thin_function<std::shared_ptr<json::Value>(
serialization::OprLoadContextRawPOD&)>;
ThinHashMap<Typeinfo*, ParamJsonTrait> m_type2param_json;
ThinHashMap<Typeinfo*, SerialParamJsonTrait> m_type2serialparam_json;
#endif
//! add single footprint calculator for associated opr type.
......@@ -70,6 +74,8 @@ public:
#if MGB_ENABLE_JSON
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> get_param_json(
cg::OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> get_serial_param_json(
Typeinfo* type, serialization::OprLoadContextRawPOD& context);
//! get opr foot print and graph exec info
//! the function will recompile graph, AsyncExecutable compiled before will
//! be invalid
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册