diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 07e802d89be48c0f05a18bbcb87452bea552c5fa..f74800bc7af6483f962ff14dd0b79db42e3def9d 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -165,14 +165,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { auto cnodes = func_graph->GetOrderedCnodes(); auto meta_graphT = std::make_unique(); for (const auto &cnode : cnodes) { - auto primitiveT_value = GetValueNode>(cnode->input(0)); - if (primitiveT_value == nullptr) { - MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; return nullptr; } - auto primT = primitiveT_value->GetPrimitiveT(); - if (primitiveT_value->Type() == schema::PrimitiveType_TupleGetItem || - primitiveT_value->Type() == schema::PrimitiveType_MakeTuple) { + auto primT = primitive_c->GetPrimitiveT(); + if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem || + primitive_c->Type() == schema::PrimitiveType_MakeTuple) { continue; } RemoveIfMakeTuple(cnode); @@ -196,7 +196,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { return nullptr; } SetOpOutputNode(cnode, meta_graphT, node.get()); - ret = ConvertQuantParam(meta_graphT, primitiveT_value, node); + ret = ConvertQuantParam(meta_graphT, primitive_c, node); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvertQuantParam failed"; return nullptr; diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 1606f611f5a8d63a63eb3c627ee663423a95925d..75c24ca7fcf0c44e8a89e5d092ad084d19001daa 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -62,12 +62,12 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) { MS_ASSERT(func_graph 
!= nullptr); auto cnodes = func_graph->GetOrderedCnodes(); for (auto &cnode : cnodes) { - auto primitiveT_value = GetValueNode>(cnode->input(0)); - if (primitiveT_value == nullptr) { - MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; return; } - auto primT = primitiveT_value->GetPrimitiveT(); + auto primT = primitive_c->GetPrimitiveT(); if (primT == nullptr) { MS_LOG(ERROR) << "PrimitiveT is nullptr"; return; @@ -75,7 +75,7 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) { if (primT->value.type == schema::PrimitiveType_TupleGetItem || primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) { delete primT; - primitiveT_value->SetPrimitiveT(nullptr); + primitive_c->SetPrimitiveT(nullptr); } } } diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index e6f426a759caaac880ad7f8f0a18c2eb5fed8fe9..7a24b7728e87c1ee887eebeafae2ae0cd3a07f30 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -534,7 +534,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct return RET_OK; } -STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitiveT_value, +STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitive_c, bool perchanel, bool depthwise) { // const vector dims = filter->dims; // perlayer @@ -552,7 +552,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr

fullname_with_scope() << " can not get value"; return RET_ERROR; } - auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num, + auto status = QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, perchanel, depthwise); if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed: " << status; @@ -573,8 +573,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr

primitiveT_value) { - if (primitiveT_value == nullptr || bias == nullptr) { +STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr primitive_c) { + if (primitive_c == nullptr || bias == nullptr) { MS_LOG(ERROR) << "null pointer!"; return RET_NULL_PTR; } @@ -583,7 +583,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptrdefault_param(); auto bias_param = std::dynamic_pointer_cast(bias_default_param); - auto active_weight_quant_params = primitiveT_value->GetInputQuantParams(); + auto active_weight_quant_params = primitive_c->GetInputQuantParams(); if (active_weight_quant_params.size() != 2) { MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); return RET_ERROR; @@ -627,7 +627,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptrAddInputQuantParam(quant_params); + primitive_c->AddInputQuantParam(quant_params); // quant bias data int32_t *quant_datas = new (std::nothrow) int32_t[shape_size]; if (quant_datas == nullptr) { @@ -683,18 +683,18 @@ STATUS PostTrainingQuantizer::QuantNode() { MS_LOG(INFO) << cnode_name << " can not do quant"; continue; } - auto primitiveT_value = GetValueNode>(cnode->input(0)); - if (primitiveT_value == nullptr) { - MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; continue; } if (input_scale.find(cnode) == input_scale.end()) { - primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); + primitive_c->SetQuantType(schema::QuantType_QUANT_NONE); continue; } - primitiveT_value->ClearInputOutputQuantParam(); + primitive_c->ClearInputOutputQuantParam(); auto op_name = cnode->fullname_with_scope(); - auto op_type = (schema::PrimitiveType)primitiveT_value->Type(); + auto op_type = (schema::PrimitiveType)primitive_c->Type(); MS_LOG(INFO) << "OpName: " << op_name; if (op_type != PrimitiveType_Conv2D 
&& op_type != PrimitiveType_DepthwiseConv2D && op_type != PrimitiveType_FullConnection) { @@ -715,35 +715,35 @@ STATUS PostTrainingQuantizer::QuantNode() { auto abstractTensor = utils::cast(abstractBase); if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) { MS_LOG(DEBUG) << "this parameter do quant"; - DoWeightQuant(input_node, primitiveT_value, false, false); + DoWeightQuant(input_node, primitive_c, false, false); } else { MS_LOG(DEBUG) << "this parameter no need to do quant"; } continue; } auto input_cnode = std::dynamic_pointer_cast(input_node); - auto input_cnode_primitiveT_value = GetValueNode>(input_cnode->input(0)); - if (input_cnode_primitiveT_value == nullptr) { + auto input_cnode_primitive_c = GetValueNode>(input_cnode->input(0)); + if (input_cnode_primitive_c == nullptr) { MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " << " PrimitiveC is null"; continue; } - if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) { - for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) { - primitiveT_value->AddInputQuantParam(quant_param); + if (!input_cnode_primitive_c->GetOutputQuantParams().empty()) { + for (auto &quant_param : input_cnode_primitive_c->GetOutputQuantParams()) { + primitive_c->AddInputQuantParam(quant_param); } } else { // do input quant double scale = input_scale[cnode]; int32_t zp = input_zero_point[cnode]; - DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value); + DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c); } } } else { // do input quant double scale = input_scale[cnode]; int32_t convInputzeropoint = input_zero_point[cnode]; - DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value); + DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c); // do weight quant auto weight = cnode->input(2); bool depthwise = op_type == PrimitiveType_DepthwiseConv2D; @@ -751,18 +751,18 @@ STATUS 
PostTrainingQuantizer::QuantNode() { if (op_type == PrimitiveType_FullConnection) { perchannel = false; } - DoWeightQuant(weight, primitiveT_value, perchannel, depthwise); + DoWeightQuant(weight, primitive_c, perchannel, depthwise); // do bias quant if (cnode->inputs().size() == 4) { auto bias = cnode->input(3); - DoBiasQuant(bias, primitiveT_value); + DoBiasQuant(bias, primitive_c); } } // do output quant double OutputScale = output_scale[cnode]; int32_t OutputZeropoint = output_zeropoint[cnode]; - DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitiveT_value); - primitiveT_value->SetQuantType(schema::QuantType_PostTraining); + DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitive_c); + primitive_c->SetQuantType(schema::QuantType_PostTraining); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 1c47ee4dff750c718d56ac7ff97b4aeb2a39c3e6..e2cdfdfecb5332fb31c9019d73b004d9406ae37e 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -95,10 +95,10 @@ class PostTrainingQuantizer : public Quantizer { STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); - STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitiveT_value, bool perchannel, + STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitive_c, bool perchannel, bool depthwise); - STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr primitiveT_value); + STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr primitive_c); }; struct DivergInfo; diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc index 
9e0c9d4cfe1db970e235d9e84f5418471350a061..c205452c2680c590623633a7a1a310c0c6ac8496 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -44,17 +44,17 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { bool first = true; for (auto &cnode : cnodes) { - auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto primitive_c = GetValueNode>(cnode->input(0)); auto curnode_quant_type = schema::QuantType_QUANT_NONE; - if (primitiveT_value == nullptr) { - MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); + if (primitive_c == nullptr) { + MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); } else { - curnode_quant_type = primitiveT_value->GetQuantType(); + curnode_quant_type = primitive_c->GetQuantType(); } if (first) { if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { auto value_node = - NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front()); + NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front()); std::vector op_inputs = {value_node, cnode->input(1)}; auto quant_cast_cnode = graph->NewCNode(op_inputs); quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); @@ -72,24 +72,24 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { continue; } auto input_cnode = std::dynamic_pointer_cast(input_node); - auto input_cnode_primitiveT_value = GetValueNode>(input_cnode->input(0)); - if (input_cnode_primitiveT_value == nullptr) { + auto input_cnode_primitive_c = GetValueNode>(input_cnode->input(0)); + if (input_cnode_primitive_c == nullptr) { MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " << " PrimitiveC is null"; continue; } - auto input_cnode_quant_type = input_cnode_primitiveT_value->GetQuantType(); + auto input_cnode_quant_type = 
input_cnode_primitive_c->GetQuantType(); if (curnode_quant_type != input_cnode_quant_type) { ValueNodePtr value_node = nullptr; if (curnode_quant_type == schema::QuantType_PostTraining && input_cnode_quant_type == schema::QuantType_QUANT_NONE) { value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, - primitiveT_value->GetInputQuantParams().front()); + primitive_c->GetInputQuantParams().front()); } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && input_cnode_quant_type == schema::QuantType_PostTraining) { value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, - input_cnode_primitiveT_value->GetInputQuantParams().front()); + input_cnode_primitive_c->GetInputQuantParams().front()); } if (value_node == nullptr) { MS_LOG(WARNING) << "value_node is null! " diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 3032ce0908d645abc068abf470aec61c4a79a9b5..e61e94291057b63d3f773ed338399322dbe0966b 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -87,13 +87,13 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { } auto cnode = std::dynamic_pointer_cast(node); - auto primitiveT_value = GetValueNode>(cnode->input(0)); - if (primitiveT_value == nullptr) { - MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); return false; } - auto type = (schema::PrimitiveType)primitiveT_value->Type(); + auto type = (schema::PrimitiveType)primitive_c->Type(); MS_LOG(INFO) << "Primitive type: " << type; static const std::vector uint8OpList = { schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, @@ -279,7 +279,7 @@ STATUS 
CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl return RET_OK; } -STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitiveT_value, QuantType quantType, +STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitive_c, QuantType quantType, int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) { auto dims = weight->tensor_shape(); if (per_channel) { @@ -450,7 +450,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti MS_LOG(ERROR) << "quant_params empty"; return RET_ERROR; } - primitiveT_value->AddInputQuantParam(quant_params); + primitive_c->AddInputQuantParam(quant_params); return RET_OK; } diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 7a95fc90cc85203b0c59b2642508d78e2a333fc0..352f969c1043299a7a93b9689b9dbb2c6afd9993 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -118,7 +118,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan }(); } -STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitiveT_value, QuantType quantType, +STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitive_c, QuantType quantType, int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false, bool depth_wise = false); diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 49d4af2cae9a9ff919b82a4beaea5ef2e754979d..a36186024325cecf84da8b656c83ad7b25010cd4 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -135,6 +135,26 @@ void FreeInputTensor(std::vector *input_tensor) { } return; } +schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) { + auto 
primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return nullptr; + } + + auto *lite_primitive = primitive_c->GetPrimitiveT(); + if (lite_primitive == nullptr) { + MS_LOG(ERROR) << "Primitive in primitive_c is nullptr"; + return nullptr; + } + + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::Primitive::Pack(builder, lite_primitive); + builder.Finish(offset); + auto buf = builder.GetBufferPointer(); + auto primitive = flatbuffers::GetRoot(buf); + return const_cast(primitive); +} const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { CheckIfFuncGraphIsNull(func_graph); @@ -155,10 +175,16 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An } MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope(); auto output_nums = GetOutputTensorNum(input_cnode); - auto primitiveT_value = GetValueNode>(input_cnode->input(0)); std::vector output_tensors{output_nums, new Tensor()}; - primitiveT_value->InferShape(input_tensors, output_tensors); - auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get()); + auto schema_primitive = PackPrimitiveT(input_cnode); + auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(schema_primitive); + if (lite_primitive == nullptr) { + MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr"; + FreeInputTensor(&input_tensors); + return nullptr; + } + lite_primitive->InferShape(input_tensors, output_tensors); + auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive); if (lite_kernel == nullptr) { MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; FreeInputTensor(&input_tensors); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc index 
99ac179907de70ce0f004d63a55662d4935c6138..67c815c80566ca2909abd10d739d5b6433871f77 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_activation_fusion.cc @@ -62,17 +62,17 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c } auto conv_node = pre_node->cast(); auto node_type = GetCNodeType(conv_node); - auto primitiveT_value = GetValueNode>(conv_node->input(0)); - MS_ASSERT(primitiveT_value); + auto primitive_c = GetValueNode>(conv_node->input(0)); + MS_ASSERT(primitive_c); if (node_type == schema::PrimitiveType_Conv2D) { - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); primc->SetActivationType(activation_type); return pre_node; } else if (node_type == schema::PrimitiveType_DepthwiseConv2D) { - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); primc->SetActivationType(activation_type); return pre_node; diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index f746240552f345a8af40e064375f9de0da98c223..d6974365959d246034d69e2905ac12d155eb80f5 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -160,22 +160,22 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons auto conv_node = conv_node_anf->cast(); CheckIfCNodeIsNull(conv_node); GenConvNewBias(func_graph, conv_node, add_node); - auto primitiveT_value = GetValueNode>(conv_node->input(0)); - MS_ASSERT(primitiveT_value != nullptr); - auto type = primitiveT_value->Type(); + auto primitive_c = 
GetValueNode>(conv_node->input(0)); + MS_ASSERT(primitive_c != nullptr); + auto type = primitive_c->Type(); if (type == schema::PrimitiveType_Conv2D) { - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); primc->SetHasBias(true); } else if (type == schema::PrimitiveType_DepthwiseConv2D) { - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); primc->SetHasBias(true); } else if (type == schema::PrimitiveType_DeConv2D) { - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); primc->SetHasBias(true); } else { diff --git a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc index 5e391ec7a30e38f6f1e53d12154e8c3f516fe3f8..e348691bcff2ecf7ffd83d219ee80b1ee7701959 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_bn_fusion.cc @@ -115,14 +115,14 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern AnfNodePtr bn_scale_node = nullptr; AnfNodePtr bn_bias_node = nullptr; float eps = 0; - auto primitiveT_value = GetValueNode>(bn_node->input(0)); + auto primitive_c = GetValueNode>(bn_node->input(0)); if (GetCNodeType(bn_node) == schema::PrimitiveType_BatchNorm) { bn_mean_node = bn_node->input(kCaffeBNMeanIndex); bn_variance_node = bn_node->input(kCaffeBNVarIndex); CheckIfNodeIsParam(bn_mean_node); CheckIfNodeIsParam(bn_variance_node); - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = 
utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); eps = primc->GetEpsilon(); } else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) { @@ -130,8 +130,8 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern bn_bias_node = bn_node->input(kTFBNBiasIndex); bn_mean_node = bn_node->input(kTFBNMeanIndex); bn_variance_node = bn_node->input(kTFBNVarIndex); - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); eps = primc->GetEpsilon(); } else { diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc index 9256a60a635045006c48a41e84b5ea7e29dbfaff..e7490933c225f0b6d71782b195acd01fbed2a49b 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -97,17 +97,17 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); delete[] trans_bias; delete[] trans_scale; - auto primitiveT_value = GetValueNode>(conv_node->input(0)); - MS_ASSERT(primitiveT_value != nullptr); - auto type = primitiveT_value->Type(); + auto primitive_c = GetValueNode>(conv_node->input(0)); + MS_ASSERT(primitive_c != nullptr); + auto type = primitive_c->Type(); if (type == schema::PrimitiveType_Conv2D) { - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); primc->SetHasBias(true); } else if (type == schema::PrimitiveType_DepthwiseConv2D) { - MS_ASSERT(utils::isa>(primitiveT_value)); - auto primc = utils::cast>(primitiveT_value); + MS_ASSERT(utils::isa>(primitive_c)); + auto primc = 
utils::cast>(primitive_c); MS_ASSERT(primc != nullptr); primc->SetHasBias(true); } else {