From 3116e9f7941646e4f43c2ed44abf851736bd8677 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 8 Jun 2023 11:30:34 +0800 Subject: [PATCH] feat(mgb): support get opr param json from serailzed param GitOrigin-RevId: f9f0dcbde103892f59f5679d483e90467ad7b5a7 --- src/plugin/impl/opr_footprint.cpp | 157 ++++++++++++++---- .../include/megbrain/plugin/opr_footprint.h | 6 + 2 files changed, 135 insertions(+), 28 deletions(-) diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index a483b63e0..3cc33f509 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -5,6 +5,7 @@ #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/images2neibs.h" +#include "megbrain/opr/dnn/layer_norm.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/pooling.h" @@ -13,6 +14,7 @@ #include "megbrain/opr/imgproc.h" #include "megbrain/opr/indexing.h" #include "megbrain/opr/internal/indexing_helper.h" +#include "megbrain/opr/internal/indexing_helper_sereg.h" #include "megbrain/opr/io.h" #include "megbrain/opr/misc.h" #include "megbrain/opr/nn_int.h" @@ -20,6 +22,7 @@ #include "megbrain/opr/standalone/nms_opr.h" #include "megbrain/opr/tensor_gen.h" #include "megbrain/opr/tensor_manip.h" +#include "megbrain/serialization/opr_load_dump.h" #if MGB_ENABLE_JSON #include "megdnn/opr_param_json.h" #endif @@ -488,12 +491,24 @@ uint64_t opr_footprint_func(cg::OperatorNodeBase* opr) { template std::shared_ptr opr_param_json_func(cg::OperatorNodeBase* opr); +template +std::shared_ptr serial_param_json_func( + serialization::OprLoadContextRawPOD& context); + +#define REGISTE_SERIAL_PARAM_JSON_FUNC(cls) \ + template <> \ + std::shared_ptr serial_param_json_func( \ + serialization::OprLoadContextRawPOD & context) { \ + return opr::opr_param_to_json(context.read_param()); \ + } + #define REGISTE_PARAM_JSON_FUNC(cls) \ template <> \ std::shared_ptr opr_param_json_func( \ cg::OperatorNodeBase * opr) { \ return opr::opr_param_to_json(opr->cast_final_safe().param()); \ - } + } \ + REGISTE_SERIAL_PARAM_JSON_FUNC(cls) REGISTE_PARAM_JSON_FUNC(Elemwise) REGISTE_PARAM_JSON_FUNC(ConvolutionForward) @@ -544,12 +559,12 @@ REGISTE_PARAM_JSON_FUNC(GaussianRNG) REGISTE_PARAM_JSON_FUNC(Linspace) REGISTE_PARAM_JSON_FUNC(Eye) REGISTE_PARAM_JSON_FUNC(CvtColor) +REGISTE_PARAM_JSON_FUNC(LayerNormBackward) +REGISTE_PARAM_JSON_FUNC(AdaptivePoolingBackward) +REGISTE_PARAM_JSON_FUNC(DropoutBackward) -template <> -std::shared_ptr opr_param_json_func( - cg::OperatorNodeBase* opr) { - auto param = opr->cast_final_safe().param(); - +std::shared_ptr dimshuffle_param2json( + const opr::Dimshuffle::Param& param) { auto pattern = json::Array::make(); for (size_t i = 0; i < param.pattern_len; i++) pattern->add(json::NumberInt::make(param.pattern[i])); @@ -561,10 +576,19 @@ std::shared_ptr opr_param_json_func( } template <> -std::shared_ptr opr_param_json_func( +std::shared_ptr opr_param_json_func( cg::OperatorNodeBase* opr) { - auto param = opr->cast_final_safe().param(); + auto param = opr->cast_final_safe().param(); + return dimshuffle_param2json(param); +} +template <> +std::shared_ptr serial_param_json_func( + serialization::OprLoadContextRawPOD& context) { + return dimshuffle_param2json(context.read_param()); +} +std::shared_ptr axis_add_remove_param2json( + const opr::AxisAddRemove::Param& param) { auto desc = json::Array::make(); for (size_t i = 0; i < param.nr_desc; i++) { auto axisdesc = param.desc[i]; @@ -581,6 +605,19 @@ std::shared_ptr opr_param_json_func( }); } +template <> +std::shared_ptr opr_param_json_func( + cg::OperatorNodeBase* opr) { + auto param = opr->cast_final_safe().param(); + return axis_add_remove_param2json(param); +} + +template <> +std::shared_ptr serial_param_json_func( + serialization::OprLoadContextRawPOD& context) { + return axis_add_remove_param2json(context.read_param()); +} + std::shared_ptr indexing_param_to_json( const std::vector& indices) { auto desc = json::Array::make(); @@ -596,12 +633,29 @@ std::shared_ptr indexing_param_to_json( return desc; } -#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ - template <> \ - std::shared_ptr opr_param_json_func( \ - cg::OperatorNodeBase * opr) { \ - auto indices = opr->cast_final_safe().index_desc(); \ - return indexing_param_to_json(indices); \ +#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ + template <> \ + std::shared_ptr opr_param_json_func( \ + cg::OperatorNodeBase * opr) { \ + auto indices = opr->cast_final_safe().index_desc(); \ + return indexing_param_to_json(indices); \ + } \ + template <> \ + std::shared_ptr serial_param_json_func( \ + serialization::OprLoadContextRawPOD & context) { \ + auto indices = context.read_param(); \ + auto desc = json::Array::make(); \ + for (size_t i = 0; i < indices.nr_item; i++) { \ + auto&& index = indices.items[i]; \ + desc->add(json::Object::make({ \ + {"axis", json::NumberInt::make(index.axis)}, \ + {"begin", json::NumberInt::make(index.begin)}, \ + {"end", json::NumberInt::make(index.end)}, \ + {"step", json::NumberInt::make(index.step)}, \ + {"idx", json::NumberInt::make(index.idx)}, \ + })); \ + } \ + return desc; \ } REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor); @@ -617,14 +671,11 @@ REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing); REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing); REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing); -template <> -std::shared_ptr opr_param_json_func( - cg::OperatorNodeBase* opr) { +std::shared_ptr reshape_param2json(const opr::Reshape::Param& param) { auto desc = json::Array::make(); - auto axis_param = opr->cast_final_safe().param(); - if (axis_param.axis != axis_param.MAX_NDIM) { + if (param.axis != param.MAX_NDIM) { return json::Object::make({ - {"axis", json::NumberInt::make(axis_param.axis)}, + {"axis", json::NumberInt::make(param.axis)}, }); } else { return json::Object::make(); @@ -632,13 +683,24 @@ std::shared_ptr opr_param_json_func( } template <> -std::shared_ptr opr_param_json_func( +std::shared_ptr opr_param_json_func( cg::OperatorNodeBase* opr) { + auto axis_param = opr->cast_final_safe().param(); + return reshape_param2json(axis_param); +} + +template <> +std::shared_ptr serial_param_json_func( + serialization::OprLoadContextRawPOD& context) { + return reshape_param2json(context.read_param()); +} + +std::shared_ptr getvarshape_param2json( + const opr::GetVarShape::Param& param) { auto desc = json::Array::make(); - auto axis_param = opr->cast_final_safe().param(); - if (axis_param.axis != axis_param.MAX_NDIM) { + if (param.axis != param.MAX_NDIM) { return json::Object::make({ - {"axis", json::NumberInt::make(axis_param.axis)}, + {"axis", json::NumberInt::make(param.axis)}, }); } else { return json::Object::make(); @@ -646,15 +708,39 @@ std::shared_ptr opr_param_json_func( } template <> -std::shared_ptr opr_param_json_func( +std::shared_ptr opr_param_json_func( cg::OperatorNodeBase* opr) { - auto nms_param = opr->cast_final_safe().param(); + auto axis_param = opr->cast_final_safe().param(); + return getvarshape_param2json(axis_param); +} + +template <> +std::shared_ptr serial_param_json_func( + serialization::OprLoadContextRawPOD& context) { + return getvarshape_param2json(context.read_param()); +} + +std::shared_ptr nmskeep_param2json( + const opr::standalone::NMSKeep::Param& param) { return json::Object::make({ - {"iou_thresh", json::Number::make(nms_param.iou_thresh)}, - {"max_output", json::Number::make(nms_param.max_output)}, + {"iou_thresh", json::Number::make(param.iou_thresh)}, + {"max_output", json::Number::make(param.max_output)}, }); } +template <> +std::shared_ptr opr_param_json_func( + cg::OperatorNodeBase* opr) { + auto nms_param = opr->cast_final_safe().param(); + return nmskeep_param2json(nms_param); +} + +template <> +std::shared_ptr serial_param_json_func( + serialization::OprLoadContextRawPOD& context) { + return nmskeep_param2json(context.read_param()); +} + #endif // MGB_ENABLE_JSON } // namespace @@ -675,6 +761,9 @@ void OprFootprint::add_single_param_json() { auto&& record = m_type2param_json.emplace( OprType::typeinfo(), opr_param_json_func); mgb_assert(record.second, "duplicate opr typeinfo"); + auto&& record1 = m_type2serialparam_json.emplace( + OprType::typeinfo(), serial_param_json_func); + mgb_assert(record1.second, "duplicate opr typeinfo"); } #endif @@ -767,6 +856,9 @@ void OprFootprint::init_all_footprints() { add_single_param_json(); add_single_param_json(); add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); #endif } @@ -814,6 +906,15 @@ std::shared_ptr OprFootprint::get_param_json(cg::OperatorNodeBase* return json::Object::make(); } +std::shared_ptr OprFootprint::get_serial_param_json( + Typeinfo* type, serialization::OprLoadContextRawPOD& context) { + auto param_trait = m_type2serialparam_json.find(type); + if (param_trait != m_type2serialparam_json.end()) { + return (param_trait->second)(context); + } + return json::Object::make(); +} + std::shared_ptr OprFootprint::Result::to_json() const { using namespace json; std::shared_ptr comp; diff --git a/src/plugin/include/megbrain/plugin/opr_footprint.h b/src/plugin/include/megbrain/plugin/opr_footprint.h index 48ba4ed4f..193c1cded 100644 --- a/src/plugin/include/megbrain/plugin/opr_footprint.h +++ b/src/plugin/include/megbrain/plugin/opr_footprint.h @@ -1,6 +1,7 @@ #pragma once #include "megbrain/graph.h" +#include "megbrain/serialization/opr_load_dump.h" namespace mgb { @@ -14,7 +15,10 @@ class OprFootprint { #if MGB_ENABLE_JSON using ParamJsonTrait = thin_function(cg::OperatorNodeBase*)>; + using SerialParamJsonTrait = thin_function( + serialization::OprLoadContextRawPOD&)>; ThinHashMap m_type2param_json; + ThinHashMap m_type2serialparam_json; #endif //! add single footprint calculator for associated opr type. @@ -70,6 +74,8 @@ public: #if MGB_ENABLE_JSON MGE_WIN_DECLSPEC_FUC std::shared_ptr get_param_json( cg::OperatorNodeBase* opr); + MGE_WIN_DECLSPEC_FUC std::shared_ptr get_serial_param_json( + Typeinfo* type, serialization::OprLoadContextRawPOD& context); //! get opr foot print and graph exec info //! the function will recompile graph, AsyncExecutable compiled before will //! be invalid -- GitLab