From 06886fd1ce006da233406e6eddcb7692774bd4e9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 29 Jun 2023 10:45:12 +0800 Subject: [PATCH] Revert "feat(mgb): support get opr param json from serailzed param" This reverts commit f9f0dcbde103892f59f5679d483e90467ad7b5a7. GitOrigin-RevId: 41bd87b7cd4eb8780b0a168aed11673bce1dba7a --- src/plugin/impl/opr_footprint.cpp | 157 ++++-------------- .../include/megbrain/plugin/opr_footprint.h | 6 - 2 files changed, 28 insertions(+), 135 deletions(-) diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index 3cc33f509..a483b63e0 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -5,7 +5,6 @@ #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/images2neibs.h" -#include "megbrain/opr/dnn/layer_norm.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/pooling.h" @@ -14,7 +13,6 @@ #include "megbrain/opr/imgproc.h" #include "megbrain/opr/indexing.h" #include "megbrain/opr/internal/indexing_helper.h" -#include "megbrain/opr/internal/indexing_helper_sereg.h" #include "megbrain/opr/io.h" #include "megbrain/opr/misc.h" #include "megbrain/opr/nn_int.h" @@ -22,7 +20,6 @@ #include "megbrain/opr/standalone/nms_opr.h" #include "megbrain/opr/tensor_gen.h" #include "megbrain/opr/tensor_manip.h" -#include "megbrain/serialization/opr_load_dump.h" #if MGB_ENABLE_JSON #include "megdnn/opr_param_json.h" #endif @@ -491,24 +488,12 @@ uint64_t opr_footprint_func(cg::OperatorNodeBase* opr) { template std::shared_ptr opr_param_json_func(cg::OperatorNodeBase* opr); -template -std::shared_ptr serial_param_json_func( - serialization::OprLoadContextRawPOD& context); - -#define REGISTE_SERIAL_PARAM_JSON_FUNC(cls) \ - template <> \ - std::shared_ptr serial_param_json_func( \ - serialization::OprLoadContextRawPOD & context) { \ - return opr::opr_param_to_json(context.read_param()); \ - } - #define REGISTE_PARAM_JSON_FUNC(cls) \ template <> \ std::shared_ptr opr_param_json_func( \ cg::OperatorNodeBase * opr) { \ return opr::opr_param_to_json(opr->cast_final_safe().param()); \ - } \ - REGISTE_SERIAL_PARAM_JSON_FUNC(cls) + } REGISTE_PARAM_JSON_FUNC(Elemwise) REGISTE_PARAM_JSON_FUNC(ConvolutionForward) @@ -559,12 +544,12 @@ REGISTE_PARAM_JSON_FUNC(GaussianRNG) REGISTE_PARAM_JSON_FUNC(Linspace) REGISTE_PARAM_JSON_FUNC(Eye) REGISTE_PARAM_JSON_FUNC(CvtColor) -REGISTE_PARAM_JSON_FUNC(LayerNormBackward) -REGISTE_PARAM_JSON_FUNC(AdaptivePoolingBackward) -REGISTE_PARAM_JSON_FUNC(DropoutBackward) -std::shared_ptr dimshuffle_param2json( - const opr::Dimshuffle::Param& param) { +template <> +std::shared_ptr opr_param_json_func( + cg::OperatorNodeBase* opr) { + auto param = opr->cast_final_safe().param(); + auto pattern = json::Array::make(); for (size_t i = 0; i < param.pattern_len; i++) pattern->add(json::NumberInt::make(param.pattern[i])); @@ -576,19 +561,10 @@ std::shared_ptr dimshuffle_param2json( } template <> -std::shared_ptr opr_param_json_func( +std::shared_ptr opr_param_json_func( cg::OperatorNodeBase* opr) { - auto param = opr->cast_final_safe().param(); - return dimshuffle_param2json(param); -} -template <> -std::shared_ptr serial_param_json_func( - serialization::OprLoadContextRawPOD& context) { - return dimshuffle_param2json(context.read_param()); -} + auto param = opr->cast_final_safe().param(); -std::shared_ptr axis_add_remove_param2json( - const opr::AxisAddRemove::Param& param) { auto desc = json::Array::make(); for (size_t i = 0; i < param.nr_desc; i++) { auto axisdesc = param.desc[i]; @@ -605,19 +581,6 @@ std::shared_ptr axis_add_remove_param2json( }); } -template <> -std::shared_ptr opr_param_json_func( - cg::OperatorNodeBase* opr) { - auto param = opr->cast_final_safe().param(); - return axis_add_remove_param2json(param); -} - -template <> -std::shared_ptr serial_param_json_func( - serialization::OprLoadContextRawPOD& context) { - return axis_add_remove_param2json(context.read_param()); -} - std::shared_ptr indexing_param_to_json( const std::vector& indices) { auto desc = json::Array::make(); @@ -633,29 +596,12 @@ std::shared_ptr indexing_param_to_json( return desc; } -#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ - template <> \ - std::shared_ptr opr_param_json_func( \ - cg::OperatorNodeBase * opr) { \ - auto indices = opr->cast_final_safe().index_desc(); \ - return indexing_param_to_json(indices); \ - } \ - template <> \ - std::shared_ptr serial_param_json_func( \ - serialization::OprLoadContextRawPOD & context) { \ - auto indices = context.read_param(); \ - auto desc = json::Array::make(); \ - for (size_t i = 0; i < indices.nr_item; i++) { \ - auto&& index = indices.items[i]; \ - desc->add(json::Object::make({ \ - {"axis", json::NumberInt::make(index.axis)}, \ - {"begin", json::NumberInt::make(index.begin)}, \ - {"end", json::NumberInt::make(index.end)}, \ - {"step", json::NumberInt::make(index.step)}, \ - {"idx", json::NumberInt::make(index.idx)}, \ - })); \ - } \ - return desc; \ +#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ + template <> \ + std::shared_ptr opr_param_json_func( \ + cg::OperatorNodeBase * opr) { \ + auto indices = opr->cast_final_safe().index_desc(); \ + return indexing_param_to_json(indices); \ } REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor); @@ -671,11 +617,14 @@ REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing); REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing); REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing); -std::shared_ptr reshape_param2json(const opr::Reshape::Param& param) { +template <> +std::shared_ptr opr_param_json_func( + cg::OperatorNodeBase* opr) { auto desc = json::Array::make(); - if (param.axis != param.MAX_NDIM) { + auto axis_param = opr->cast_final_safe().param(); + if (axis_param.axis != axis_param.MAX_NDIM) { return json::Object::make({ - {"axis", json::NumberInt::make(param.axis)}, + {"axis", json::NumberInt::make(axis_param.axis)}, }); } else { return json::Object::make(); @@ -683,62 +632,27 @@ std::shared_ptr reshape_param2json(const opr::Reshape::Param& param } template <> -std::shared_ptr opr_param_json_func( +std::shared_ptr opr_param_json_func( cg::OperatorNodeBase* opr) { - auto axis_param = opr->cast_final_safe().param(); - return reshape_param2json(axis_param); -} - -template <> -std::shared_ptr serial_param_json_func( - serialization::OprLoadContextRawPOD& context) { - return reshape_param2json(context.read_param()); -} - -std::shared_ptr getvarshape_param2json( - const opr::GetVarShape::Param& param) { auto desc = json::Array::make(); - if (param.axis != param.MAX_NDIM) { + auto axis_param = opr->cast_final_safe().param(); + if (axis_param.axis != axis_param.MAX_NDIM) { return json::Object::make({ - {"axis", json::NumberInt::make(param.axis)}, + {"axis", json::NumberInt::make(axis_param.axis)}, }); } else { return json::Object::make(); } } -template <> -std::shared_ptr opr_param_json_func( - cg::OperatorNodeBase* opr) { - auto axis_param = opr->cast_final_safe().param(); - return getvarshape_param2json(axis_param); -} - -template <> -std::shared_ptr serial_param_json_func( - serialization::OprLoadContextRawPOD& context) { - return getvarshape_param2json(context.read_param()); -} - -std::shared_ptr nmskeep_param2json( - const opr::standalone::NMSKeep::Param& param) { - return json::Object::make({ - {"iou_thresh", json::Number::make(param.iou_thresh)}, - {"max_output", json::Number::make(param.max_output)}, - }); -} - template <> std::shared_ptr opr_param_json_func( cg::OperatorNodeBase* opr) { auto nms_param = opr->cast_final_safe().param(); - return nmskeep_param2json(nms_param); -} - -template <> -std::shared_ptr serial_param_json_func( - serialization::OprLoadContextRawPOD& context) { - return nmskeep_param2json(context.read_param()); + return json::Object::make({ + {"iou_thresh", json::Number::make(nms_param.iou_thresh)}, + {"max_output", json::Number::make(nms_param.max_output)}, + }); } #endif // MGB_ENABLE_JSON @@ -761,9 +675,6 @@ void OprFootprint::add_single_param_json() { auto&& record = m_type2param_json.emplace( OprType::typeinfo(), opr_param_json_func); mgb_assert(record.second, "duplicate opr typeinfo"); - auto&& record1 = m_type2serialparam_json.emplace( - OprType::typeinfo(), serial_param_json_func); - mgb_assert(record1.second, "duplicate opr typeinfo"); } #endif @@ -856,9 +767,6 @@ void OprFootprint::init_all_footprints() { add_single_param_json(); add_single_param_json(); add_single_param_json(); - add_single_param_json(); - add_single_param_json(); - add_single_param_json(); #endif } @@ -906,15 +814,6 @@ std::shared_ptr OprFootprint::get_param_json(cg::OperatorNodeBase* return json::Object::make(); } -std::shared_ptr OprFootprint::get_serial_param_json( - Typeinfo* type, serialization::OprLoadContextRawPOD& context) { - auto param_trait = m_type2serialparam_json.find(type); - if (param_trait != m_type2serialparam_json.end()) { - return (param_trait->second)(context); - } - return json::Object::make(); -} - std::shared_ptr OprFootprint::Result::to_json() const { using namespace json; std::shared_ptr comp; diff --git a/src/plugin/include/megbrain/plugin/opr_footprint.h b/src/plugin/include/megbrain/plugin/opr_footprint.h index 193c1cded..48ba4ed4f 100644 --- a/src/plugin/include/megbrain/plugin/opr_footprint.h +++ b/src/plugin/include/megbrain/plugin/opr_footprint.h @@ -1,7 +1,6 @@ #pragma once #include "megbrain/graph.h" -#include "megbrain/serialization/opr_load_dump.h" namespace mgb { @@ -15,10 +14,7 @@ class OprFootprint { #if MGB_ENABLE_JSON using ParamJsonTrait = thin_function(cg::OperatorNodeBase*)>; - using SerialParamJsonTrait = thin_function( - serialization::OprLoadContextRawPOD&)>; ThinHashMap m_type2param_json; - ThinHashMap m_type2serialparam_json; #endif //! add single footprint calculator for associated opr type. @@ -74,8 +70,6 @@ public: #if MGB_ENABLE_JSON MGE_WIN_DECLSPEC_FUC std::shared_ptr get_param_json( cg::OperatorNodeBase* opr); - MGE_WIN_DECLSPEC_FUC std::shared_ptr get_serial_param_json( - Typeinfo* type, serialization::OprLoadContextRawPOD& context); //! get opr foot print and graph exec info //! the function will recompile graph, AsyncExecutable compiled before will //! be invalid -- GitLab