diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
index 5d329f19774ee5d8ba8f3283a182f0f409e647fa..1606f611f5a8d63a63eb3c627ee663423a95925d 100644
--- a/mindspore/lite/tools/converter/converter.cc
+++ b/mindspore/lite/tools/converter/converter.cc
@@ -73,13 +73,11 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
       return;
     }
     if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
-        primT->value.type == schema::PrimitiveType_MakeTuple ||
-        primT->value.type == schema::PrimitiveType_Return) {
+        primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) {
       delete primT;
       primitiveT_value->SetPrimitiveT(nullptr);
     }
   }
-  return;
 }
 MetaGraphT *Converter::Convert(const converter::Flags *flag) {
   // parse the model and weight file to generate inference data structure
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index 95f8a48810588aa3d9ac66a4e018d3e1893e2a9f..3032ce0908d645abc068abf470aec61c4a79a9b5 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -93,7 +93,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
     return false;
   }

-  auto type = primitiveT_value->GetPrimitiveT()->value.type;
+  auto type = (schema::PrimitiveType)primitiveT_value->Type();
   MS_LOG(INFO) << "Primitive type: " << type;
   static const std::vector<schema::PrimitiveType> uint8OpList = {
     schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
index c4c1aace5b832ac8515ba8c9ca6485cfad6781bd..e444b6dc4bb0881d30681579383ae8a77e5991f3 100644
--- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc
+++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc
@@ -170,7 +170,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
     if (a.m_ptr->isa<lite::PrimitiveTValue>() && b.m_ptr->isa<lite::PrimitiveTValue>()) {
       auto a_value_node_ptr = a.m_ptr->cast<lite::PrimitiveTValuePtr>();
       auto b_value_node_ptr = b.m_ptr->cast<lite::PrimitiveTValuePtr>();
-      return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
+      return a_value_node_ptr->Type() == b_value_node_ptr->Type();
     }

     return a == b;
@@ -316,7 +316,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
   if (utils::isa<lite::PrimitiveTValuePtr>(value)) {
     auto primitive = value->cast<lite::PrimitiveTValuePtr>();
     MS_ASSERT(primitive != nullptr);
-    return primitive->GetPrimitiveT()->value.type;
+    return (schema::PrimitiveType)primitive->Type();
   } else if (utils::isa<PrimitivePtr>(value)) {
     auto primitive = value->cast<PrimitivePtr>();
     MS_ASSERT(primitive != nullptr);
diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
index 5043df1c92c7c4faa6e4046a7c6faf9fa8d48bfe..49d4af2cae9a9ff919b82a4beaea5ef2e754979d 100644
--- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
@@ -73,26 +73,6 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
   }
   return input_tensors;
 }
-schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
-  auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
-  if (primitiveT_value == nullptr) {
-    MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
-    return nullptr;
-  }
-
-  auto *lite_primitive = primitiveT_value->GetPrimitiveT();
-  if (lite_primitive == nullptr) {
-    MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr";
is nullptr"; - return nullptr; - } - - flatbuffers::FlatBufferBuilder builder(1024); - auto offset = schema::Primitive::Pack(builder, lite_primitive); - builder.Finish(offset); - auto buf = builder.GetBufferPointer(); - auto primitive = flatbuffers::GetRoot(buf); - return const_cast(primitive); -} const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { auto parameter = func_graph->add_parameter(); std::vector shape; @@ -175,16 +155,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An } MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope(); auto output_nums = GetOutputTensorNum(input_cnode); + auto primitiveT_value = GetValueNode>(input_cnode->input(0)); std::vector output_tensors{output_nums, new Tensor()}; - auto scheam_primitive = PackPrimitiveT(input_cnode); - auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive); - if (lite_primitive == nullptr) { - MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr"; - FreeInputTensor(&input_tensors); - return nullptr; - } - lite_primitive->InferShape(input_tensors, output_tensors); - auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive); + primitiveT_value->InferShape(input_tensors, output_tensors); + auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get()); if (lite_kernel == nullptr) { MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; FreeInputTensor(&input_tensors); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc index b02eccd3fe948be197f1ccca39de43e6a31b8e72..5e391ec7a30e38f6f1e53d12154e8c3f516fe3f8 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc @@ -22,6 +22,8 @@ #include "utils/utils.h" #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" +#include "src/ops/batch_norm.h" +#include "src/ops/fused_batchnorm.h" namespace mindspore::opt { namespace { @@ -94,7 +96,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const { auto bn_mean_var = std::make_shared(IsParamNode); auto bn_variable_var = std::make_shared(IsParamNode); auto bn_other_var = std::make_shared(); - return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});; + return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var}); } // BatchNorm weight Tensor definition: // caffe @@ -106,7 +108,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const { // estimated_mean --2 // estimated_variance --3 const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale, - float *trans_bias) const { + float *trans_bias) const { MS_ASSERT(bn_node != nullptr); AnfNodePtr bn_mean_node = nullptr; AnfNodePtr bn_variance_node = nullptr; @@ -119,13 +121,19 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern bn_variance_node = bn_node->input(kCaffeBNVarIndex); CheckIfNodeIsParam(bn_mean_node); CheckIfNodeIsParam(bn_variance_node); - eps = primitiveT_value->GetPrimitiveT()->value.AsBatchNorm()->epsilon; + MS_ASSERT(utils::isa>(primitiveT_value)); + auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(primc != nullptr); + eps = primc->GetEpsilon(); } else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) { bn_scale_node = bn_node->input(kTFBNScaleIndex); 
     bn_bias_node = bn_node->input(kTFBNBiasIndex);
     bn_mean_node = bn_node->input(kTFBNMeanIndex);
     bn_variance_node = bn_node->input(kTFBNVarIndex);
-    eps = primitiveT_value->GetPrimitiveT()->value.AsFusedBatchNorm()->epsilon;
+    MS_ASSERT(utils::isa<std::shared_ptr<lite::FusedBatchNorm>>(primitiveT_value));
+    auto primc = utils::cast<std::shared_ptr<lite::FusedBatchNorm>>(primitiveT_value);
+    MS_ASSERT(primc != nullptr);
+    eps = primc->GetEpsilon();
   } else {
     MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op.";
   }
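
Editor's note on the pattern applied across these hunks: every call site stops reaching into the backing flatbuffer object (primitiveT_value->GetPrimitiveT()->value.type, ...->value.AsBatchNorm()->epsilon) and instead goes through the wrapper's accessors (Type(), GetEpsilon()), so callers no longer depend on the schema::PrimitiveT layout or on nullptr-prone GetPrimitiveT() chains. The stand-alone C++ sketch below illustrates that shape in isolation; PrimitiveType, PrimitiveTValue, and BatchNorm here are simplified stand-ins written for this note, not the real MindSpore declarations.

// Illustrative stand-ins only: the real classes live under mindspore/lite/src/ops/
// and wrap a flatbuffer-generated schema::PrimitiveT.
#include <iostream>
#include <memory>

enum class PrimitiveType { BatchNorm, FusedBatchNorm };

// Base wrapper: callers ask for Type() instead of dereferencing the
// underlying schema table (the old GetPrimitiveT()->value.type pattern).
class PrimitiveTValue {
 public:
  virtual ~PrimitiveTValue() = default;
  virtual PrimitiveType Type() const = 0;
};

// Typed op exposing a named getter, mirroring lite::BatchNorm::GetEpsilon()
// in the diff (instead of ...->value.AsBatchNorm()->epsilon).
class BatchNorm : public PrimitiveTValue {
 public:
  explicit BatchNorm(float epsilon) : epsilon_(epsilon) {}
  PrimitiveType Type() const override { return PrimitiveType::BatchNorm; }
  float GetEpsilon() const { return epsilon_; }

 private:
  float epsilon_;
};

int main() {
  std::shared_ptr<PrimitiveTValue> prim = std::make_shared<BatchNorm>(1e-5f);

  // New-style access: dispatch on Type(), then downcast to the typed op,
  // analogous to what the conv_bn_fusion hunks do with utils::isa/utils::cast.
  if (prim->Type() == PrimitiveType::BatchNorm) {
    auto bn = std::static_pointer_cast<BatchNorm>(prim);
    std::cout << "epsilon = " << bn->GetEpsilon() << "\n";
  }
  return 0;
}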