提交 fa671d67 编写于 作者: M Megvii Engine Team

feat(mgb): support get opr param json from serailzed param

GitOrigin-RevId: c4cabb6f700b722e61c5944bcd16cceab13f513d
上级 06886fd1
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/images2neibs.h" #include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/pooling.h"
...@@ -13,6 +14,7 @@ ...@@ -13,6 +14,7 @@
#include "megbrain/opr/imgproc.h" #include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h" #include "megbrain/opr/indexing.h"
#include "megbrain/opr/internal/indexing_helper.h" #include "megbrain/opr/internal/indexing_helper.h"
#include "megbrain/opr/internal/indexing_helper_sereg.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h" #include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h" #include "megbrain/opr/nn_int.h"
...@@ -20,6 +22,7 @@ ...@@ -20,6 +22,7 @@
#include "megbrain/opr/standalone/nms_opr.h" #include "megbrain/opr/standalone/nms_opr.h"
#include "megbrain/opr/tensor_gen.h" #include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/opr_load_dump.h"
#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
#include "megdnn/opr_param_json.h" #include "megdnn/opr_param_json.h"
#endif #endif
...@@ -488,12 +491,24 @@ uint64_t opr_footprint_func<opr::Host2DeviceCopy>(cg::OperatorNodeBase* opr) { ...@@ -488,12 +491,24 @@ uint64_t opr_footprint_func<opr::Host2DeviceCopy>(cg::OperatorNodeBase* opr) {
template <class T> template <class T>
std::shared_ptr<json::Value> opr_param_json_func(cg::OperatorNodeBase* opr); std::shared_ptr<json::Value> opr_param_json_func(cg::OperatorNodeBase* opr);
template <class T>
std::shared_ptr<json::Value> serial_param_json_func(
serialization::OprLoadContextRawPOD& context);
#define REGISTE_SERIAL_PARAM_JSON_FUNC(cls) \
template <> \
std::shared_ptr<json::Value> serial_param_json_func<opr::cls>( \
serialization::OprLoadContextRawPOD & context) { \
return opr::opr_param_to_json(context.read_param<opr::cls::Param>()); \
}
#define REGISTE_PARAM_JSON_FUNC(cls) \ #define REGISTE_PARAM_JSON_FUNC(cls) \
template <> \ template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \ std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \ cg::OperatorNodeBase * opr) { \
return opr::opr_param_to_json(opr->cast_final_safe<opr::cls>().param()); \ return opr::opr_param_to_json(opr->cast_final_safe<opr::cls>().param()); \
} } \
REGISTE_SERIAL_PARAM_JSON_FUNC(cls)
REGISTE_PARAM_JSON_FUNC(Elemwise) REGISTE_PARAM_JSON_FUNC(Elemwise)
REGISTE_PARAM_JSON_FUNC(ConvolutionForward) REGISTE_PARAM_JSON_FUNC(ConvolutionForward)
...@@ -544,12 +559,12 @@ REGISTE_PARAM_JSON_FUNC(GaussianRNG) ...@@ -544,12 +559,12 @@ REGISTE_PARAM_JSON_FUNC(GaussianRNG)
REGISTE_PARAM_JSON_FUNC(Linspace) REGISTE_PARAM_JSON_FUNC(Linspace)
REGISTE_PARAM_JSON_FUNC(Eye) REGISTE_PARAM_JSON_FUNC(Eye)
REGISTE_PARAM_JSON_FUNC(CvtColor) REGISTE_PARAM_JSON_FUNC(CvtColor)
REGISTE_PARAM_JSON_FUNC(LayerNormBackward)
REGISTE_PARAM_JSON_FUNC(AdaptivePoolingBackward)
REGISTE_PARAM_JSON_FUNC(DropoutBackward)
template <> std::shared_ptr<json::Value> dimshuffle_param2json(
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>( const opr::Dimshuffle::Param& param) {
cg::OperatorNodeBase* opr) {
auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
auto pattern = json::Array::make(); auto pattern = json::Array::make();
for (size_t i = 0; i < param.pattern_len; i++) for (size_t i = 0; i < param.pattern_len; i++)
pattern->add(json::NumberInt::make(param.pattern[i])); pattern->add(json::NumberInt::make(param.pattern[i]));
...@@ -561,10 +576,19 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>( ...@@ -561,10 +576,19 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
} }
template <> template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>( std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(
cg::OperatorNodeBase* opr) { cg::OperatorNodeBase* opr) {
auto param = opr->cast_final_safe<opr::AxisAddRemove>().param(); auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
return dimshuffle_param2json(param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::Dimshuffle>(
serialization::OprLoadContextRawPOD& context) {
return dimshuffle_param2json(context.read_param<opr::Dimshuffle::Param>());
}
std::shared_ptr<json::Value> axis_add_remove_param2json(
const opr::AxisAddRemove::Param& param) {
auto desc = json::Array::make(); auto desc = json::Array::make();
for (size_t i = 0; i < param.nr_desc; i++) { for (size_t i = 0; i < param.nr_desc; i++) {
auto axisdesc = param.desc[i]; auto axisdesc = param.desc[i];
...@@ -581,6 +605,19 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>( ...@@ -581,6 +605,19 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
}); });
} }
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
cg::OperatorNodeBase* opr) {
auto param = opr->cast_final_safe<opr::AxisAddRemove>().param();
return axis_add_remove_param2json(param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::AxisAddRemove>(
serialization::OprLoadContextRawPOD& context) {
return axis_add_remove_param2json(context.read_param<opr::AxisAddRemove::Param>());
}
std::shared_ptr<json::Value> indexing_param_to_json( std::shared_ptr<json::Value> indexing_param_to_json(
const std::vector<opr::indexing::AxisIndexer>& indices) { const std::vector<opr::indexing::AxisIndexer>& indices) {
auto desc = json::Array::make(); auto desc = json::Array::make();
...@@ -596,12 +633,29 @@ std::shared_ptr<json::Value> indexing_param_to_json( ...@@ -596,12 +633,29 @@ std::shared_ptr<json::Value> indexing_param_to_json(
return desc; return desc;
} }
#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ #define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \
template <> \ template <> \
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \ std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \
cg::OperatorNodeBase * opr) { \ cg::OperatorNodeBase * opr) { \
auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \ auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \
return indexing_param_to_json(indices); \ return indexing_param_to_json(indices); \
} \
template <> \
std::shared_ptr<json::Value> serial_param_json_func<opr::cls>( \
serialization::OprLoadContextRawPOD & context) { \
auto indices = context.read_param<serialization::IndexDescMaskDump>(); \
auto desc = json::Array::make(); \
for (size_t i = 0; i < indices.nr_item; i++) { \
auto&& index = indices.items[i]; \
desc->add(json::Object::make({ \
{"axis", json::NumberInt::make(index.axis)}, \
{"begin", json::NumberInt::make(index.begin)}, \
{"end", json::NumberInt::make(index.end)}, \
{"step", json::NumberInt::make(index.step)}, \
{"idx", json::NumberInt::make(index.idx)}, \
})); \
} \
return desc; \
} }
REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor); REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor);
...@@ -617,14 +671,11 @@ REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing); ...@@ -617,14 +671,11 @@ REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing); REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing);
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing); REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing);
template <> std::shared_ptr<json::Value> reshape_param2json(const opr::Reshape::Param& param) {
std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
cg::OperatorNodeBase* opr) {
auto desc = json::Array::make(); auto desc = json::Array::make();
auto axis_param = opr->cast_final_safe<opr::Reshape>().param(); if (param.axis != param.MAX_NDIM) {
if (axis_param.axis != axis_param.MAX_NDIM) {
return json::Object::make({ return json::Object::make({
{"axis", json::NumberInt::make(axis_param.axis)}, {"axis", json::NumberInt::make(param.axis)},
}); });
} else { } else {
return json::Object::make(); return json::Object::make();
...@@ -632,13 +683,24 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>( ...@@ -632,13 +683,24 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
} }
template <> template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>( std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
cg::OperatorNodeBase* opr) { cg::OperatorNodeBase* opr) {
auto axis_param = opr->cast_final_safe<opr::Reshape>().param();
return reshape_param2json(axis_param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::Reshape>(
serialization::OprLoadContextRawPOD& context) {
return reshape_param2json(context.read_param<opr::Reshape::Param>());
}
std::shared_ptr<json::Value> getvarshape_param2json(
const opr::GetVarShape::Param& param) {
auto desc = json::Array::make(); auto desc = json::Array::make();
auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param(); if (param.axis != param.MAX_NDIM) {
if (axis_param.axis != axis_param.MAX_NDIM) {
return json::Object::make({ return json::Object::make({
{"axis", json::NumberInt::make(axis_param.axis)}, {"axis", json::NumberInt::make(param.axis)},
}); });
} else { } else {
return json::Object::make(); return json::Object::make();
...@@ -646,15 +708,39 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>( ...@@ -646,15 +708,39 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>(
} }
template <> template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>( std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>(
cg::OperatorNodeBase* opr) { cg::OperatorNodeBase* opr) {
auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param(); auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param();
return getvarshape_param2json(axis_param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::GetVarShape>(
serialization::OprLoadContextRawPOD& context) {
return getvarshape_param2json(context.read_param<opr::GetVarShape::Param>());
}
std::shared_ptr<json::Value> nmskeep_param2json(
const opr::standalone::NMSKeep::Param& param) {
return json::Object::make({ return json::Object::make({
{"iou_thresh", json::Number::make(nms_param.iou_thresh)}, {"iou_thresh", json::Number::make(param.iou_thresh)},
{"max_output", json::Number::make(nms_param.max_output)}, {"max_output", json::Number::make(param.max_output)},
}); });
} }
template <>
std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>(
cg::OperatorNodeBase* opr) {
auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param();
return nmskeep_param2json(nms_param);
}
template <>
std::shared_ptr<json::Value> serial_param_json_func<opr::standalone::NMSKeep>(
serialization::OprLoadContextRawPOD& context) {
return nmskeep_param2json(context.read_param<opr::standalone::NMSKeep::Param>());
}
#endif // MGB_ENABLE_JSON #endif // MGB_ENABLE_JSON
} // namespace } // namespace
...@@ -675,6 +761,9 @@ void OprFootprint::add_single_param_json() { ...@@ -675,6 +761,9 @@ void OprFootprint::add_single_param_json() {
auto&& record = m_type2param_json.emplace( auto&& record = m_type2param_json.emplace(
OprType::typeinfo(), opr_param_json_func<OprType>); OprType::typeinfo(), opr_param_json_func<OprType>);
mgb_assert(record.second, "duplicate opr typeinfo"); mgb_assert(record.second, "duplicate opr typeinfo");
auto&& record1 = m_type2serialparam_json.emplace(
OprType::typeinfo(), serial_param_json_func<OprType>);
mgb_assert(record1.second, "duplicate opr typeinfo");
} }
#endif #endif
...@@ -767,6 +856,9 @@ void OprFootprint::init_all_footprints() { ...@@ -767,6 +856,9 @@ void OprFootprint::init_all_footprints() {
add_single_param_json<opr::Eye>(); add_single_param_json<opr::Eye>();
add_single_param_json<opr::standalone::NMSKeep>(); add_single_param_json<opr::standalone::NMSKeep>();
add_single_param_json<opr::CvtColor>(); add_single_param_json<opr::CvtColor>();
add_single_param_json<opr::LayerNormBackward>();
add_single_param_json<opr::AdaptivePoolingBackward>();
add_single_param_json<opr::DropoutBackward>();
#endif #endif
} }
...@@ -814,6 +906,15 @@ std::shared_ptr<json::Value> OprFootprint::get_param_json(cg::OperatorNodeBase* ...@@ -814,6 +906,15 @@ std::shared_ptr<json::Value> OprFootprint::get_param_json(cg::OperatorNodeBase*
return json::Object::make(); return json::Object::make();
} }
std::shared_ptr<json::Value> OprFootprint::get_serial_param_json(
Typeinfo* type, serialization::OprLoadContextRawPOD& context) {
auto param_trait = m_type2serialparam_json.find(type);
if (param_trait != m_type2serialparam_json.end()) {
return (param_trait->second)(context);
}
return json::Object::make();
}
std::shared_ptr<json::Value> OprFootprint::Result::to_json() const { std::shared_ptr<json::Value> OprFootprint::Result::to_json() const {
using namespace json; using namespace json;
std::shared_ptr<Value> comp; std::shared_ptr<Value> comp;
......
#pragma once #pragma once
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/serialization/opr_load_dump.h"
namespace mgb { namespace mgb {
...@@ -14,7 +15,10 @@ class OprFootprint { ...@@ -14,7 +15,10 @@ class OprFootprint {
#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
using ParamJsonTrait = using ParamJsonTrait =
thin_function<std::shared_ptr<json::Value>(cg::OperatorNodeBase*)>; thin_function<std::shared_ptr<json::Value>(cg::OperatorNodeBase*)>;
using SerialParamJsonTrait = thin_function<std::shared_ptr<json::Value>(
serialization::OprLoadContextRawPOD&)>;
ThinHashMap<Typeinfo*, ParamJsonTrait> m_type2param_json; ThinHashMap<Typeinfo*, ParamJsonTrait> m_type2param_json;
ThinHashMap<Typeinfo*, SerialParamJsonTrait> m_type2serialparam_json;
#endif #endif
//! add single footprint calculator for associated opr type. //! add single footprint calculator for associated opr type.
...@@ -70,6 +74,8 @@ public: ...@@ -70,6 +74,8 @@ public:
#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> get_param_json( MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> get_param_json(
cg::OperatorNodeBase* opr); cg::OperatorNodeBase* opr);
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> get_serial_param_json(
Typeinfo* type, serialization::OprLoadContextRawPOD& context);
//! get opr foot print and graph exec info //! get opr foot print and graph exec info
//! the function will recompile graph, AsyncExecutable compiled before will //! the function will recompile graph, AsyncExecutable compiled before will
//! be invalid //! be invalid
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册