From 204ab1157220e5d55dea15f60e6ca34cf5972f58 Mon Sep 17 00:00:00 2001
From: zhengjun10
Date: Sat, 22 Aug 2020 17:00:41 +0800
Subject: [PATCH] add bn convert scale pass
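
This pass folds each frozen (Fused)BatchNorm in the converted graph into an
equivalent channel-wise Scale op:

    y = scale * (x - mean) / sqrt(variance + eps) + bias
      = transScale * x + transBias

with

    transScale = scale / sqrt(variance + eps)   (scale == 1 for caffe)
    transBias  = bias - transScale * mean       (bias == 0 for caffe)

Five caffe models are temporarily disabled in models_caffe.cfg.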
---
 mindspore/lite/test/models_caffe.cfg              |  10 +-
 .../tools/converter/graphdef_transform.cc         |  13 +
 .../legacy_optimizer/fusion/CMakeLists.txt        |   1 +
 .../fusion/batchnorm_convert_scale_pass.cc        | 383 ++++++++++++++++++
 .../fusion/batchnorm_convert_scale_pass.h         | 100 +++++
 5 files changed, 502 insertions(+), 5 deletions(-)
 create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.cc
 create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h

diff --git a/mindspore/lite/test/models_caffe.cfg b/mindspore/lite/test/models_caffe.cfg
index 6c7dc8728..276548234 100644
--- a/mindspore/lite/test/models_caffe.cfg
+++ b/mindspore/lite/test/models_caffe.cfg
@@ -16,7 +16,7 @@ tracking
 mtk_isface
 mtk_landmark
 mtk_pose_tuku
-mtk_face_recognition_v1
+# mtk_face_recognition_v1
 mtk_2012_ATLANTA_10class_20190614_v41
 mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified
 detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified
@@ -28,16 +28,16 @@ ml_hand_detection
 ml_ocr_cn
 ml_ocr_sfz_detect_0325
 ml_hardware_liveness
-ml_liveness_detect_landmark
+# ml_liveness_detect_landmark
 ml_face_contour
 2012_ATLANTA_1class_20190621_v4.x_nomean
 ml_handpose
 ml_ocr_sfz_add_final_0325
-ml_hardware_pose
+# ml_hardware_pose
 ml_bank_recog
 2012_ATLANTA_10class_20190131_v4.0
 mnet
-recognition
+# recognition
 ml_face_landmark
 model_hebing_3branch
 hiai_cv_focusShootOCRModel_07
@@ -50,7 +50,7 @@ hiai_cpu_face_hat
 hiai_video_seg
 hiai_semantic_seg
 hiai_human_seg
-hiai_face_recognition_1
+# hiai_face_recognition_1
 hiai_cpu_face_detect
 detect-mbv1-shortcut-400-400_nopostprocess_simplified
 detect_mbv1_640_480_nopostprocess_simplified
diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
index 741a06862..9d5d45c7c 100644
--- a/mindspore/lite/tools/converter/graphdef_transform.cc
+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
@@ -27,6 +27,7 @@
 #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
+#include "tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h"
 #include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h"
 #include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
 #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
@@ -126,6 +127,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     }
   }
 
+  // postconvert pass: fold remaining BatchNorm ops into Scale ops
+  {
+    Optimizer fusionOptimizer;
+    fusionOptimizer.AddPass(new (std::nothrow) BatchNormConvertScalePass());
+    fusionOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
+    status = fusionOptimizer.Run(graphDefT);
+    if (status != RET_OK && status != RET_NO_CHANGE) {
+      MS_LOG(ERROR) << "Run fusionOptimizer BatchNormConvertScalePass Failed";
+      return status;
+    }
+  }
   // format transform
   if (ctx.formatTrans) {
     Optimizer formatTransOptimizer;
@@ -187,6 +199,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
       return status;
     }
   }
+
   // topological sorting
   {
     Optimizer topologicalOptimizer;
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt
index c31eed766..238131a1a 100755
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt
@@ -6,6 +6,7 @@ add_library(fusion_mid OBJECT
         ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_transpose_fusion_pass.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc
         )
 target_link_libraries(fusion_mid securec)
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.cc
new file mode 100644
index 000000000..fbbd4adcb
--- /dev/null
+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.cc
@@ -0,0 +1,383 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h"
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "utils/log_adapter.h"
+#include "tools/common/graph_util.h"
+#include "tools/common/tensor_util.h"
+#include "include/errorcode.h"
+#include "schema/inner/model_generated.h"
+#include "src/common/op_utils.h"
+
+namespace mindspore {
+namespace lite {
+#define CAFFE_BATCHNORM_OP_WEIGHT_NUM 2
+#define TF_BATCHNORM_OP_WEIGHT_NUM 4
+#define CAFFE_BATCHNORM_MEAN_INDEX 0
+#define CAFFE_BATCHNORM_VARIANCE_INDEX 1
+#define TF_BATCHNORM_SCALE_INDEX 0
+#define TF_BATCHNORM_BIAS_INDEX 1
+#define TF_BATCHNORM_MEAN_INDEX 2
+#define TF_BATCHNORM_VARIANCE_INDEX 3
+namespace {
+constexpr const float EPS = 1e-8;
+constexpr const float EPS_DEFAULT_FLOAT = 1e-5;
+constexpr const float POW_NUM = 0.5;
+constexpr const int32_t NCHW_DIM_C = 1;
+}  // namespace
+STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); }
+
+STATUS BatchNormConvertScalePass::DefinePattern() {
+  // match a BatchNorm node together with its preceding node
+  {
+    auto inputOp = std::make_shared<PatternOp>();
+    inputOp->id = inputOpName;
+    inputOp->types = {schema::PrimitiveType_NONE};
+    inputOp->isPlaceHold = true;
+
+    auto bnOp = std::make_shared<PatternOp>();
+    bnOp->id = bnOpName;
+    bnOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_BatchNorm};
+    bnOp->left = inputOp;
+
+    std::unique_ptr<FusionPattern> fusionPattern(new (std::nothrow) FusionPattern(bnPatternName));
+    if (fusionPattern == nullptr) {
+      MS_LOG(ERROR) << "new fusionPattern failed";
+      return RET_ERROR;
+    }
+    fusionPattern->AddPatternOp(inputOp);
+    fusionPattern->AddPatternOp(bnOp);
+    fusionPattern->Finish();
+
+    this->patterns.emplace_back(fusionPattern.release());
+  }
+
+  return RET_OK;
+}
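+// DoFusion handles one matched (input -> BatchNorm) path: it locates the BN
+// node, folds its statistics into per-channel scale/bias terms, materializes
+// them as tensors and replaces the BN node with a Scale node.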
+STATUS BatchNormConvertScalePass::DoFusion(MetaGraphT *graph, const std::string &patternName,
+                                           std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
+  MS_ASSERT(graph != nullptr);
+  if (patternName != bnPatternName) {
+    MS_LOG(ERROR) << "BatchNormConvertScale-Fusion match failed";
+    return RET_PARAM_INVALID;
+  }
+  auto status = FindNodes(graph, matchedPath);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "FindNodes failed: " << status;
+    return status;
+  }
+  auto type = bnNode->primitive->value.type;
+  if (type != schema::PrimitiveType_FusedBatchNorm && type != schema::PrimitiveType_BatchNorm) {
+    return RET_OK;
+  }
+  auto bnPath = matchedPath.at(bnOpName);
+  status = GetTransParam(graph, bnPath);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "GetTransParam failed: " << status;
+    return status;
+  }
+
+  status = GenNewScaleTensor(graph, bnPath);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
+    return status;
+  }
+
+  status = ConvertBNToScale(graph, bnPath);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "ConvertBNToScale failed: " << status;
+    return status;
+  }
+  return RET_OK;
+}
+STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) {
+  auto scaleNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
+  if (scaleNode == nullptr) {
+    MS_LOG(ERROR) << "new scaleNode failed";
+    return RET_ERROR;
+  }
+  scaleNode->name = bnNode->name;
+  scaleNode->primitive = std::make_unique<PrimitiveT>();
+  if (scaleNode->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+  scaleNode->primitive->value.type = schema::PrimitiveType_Scale;
+  std::unique_ptr<ScaleT> scaleParam(new (std::nothrow) ScaleT());
+  if (scaleParam == nullptr) {
+    MS_LOG(ERROR) << "new scaleParam failed";
+    return RET_ERROR;
+  }
+  scaleParam->axis = NCHW_DIM_C;
+  scaleNode->primitive->value.value = scaleParam.release();
+  auto scaleIter = graph->nodes.begin() + bnPath->nodeIdx;
+  STATUS errorCode = RET_OK;
+  scaleIter = InsertNode(graph, scaleIter, kBefore, 0, std::move(scaleNode), &errorCode, ScaleOpCopyer);
+  if (errorCode != RET_OK) {
+    MS_LOG(ERROR) << "InsertNode failed: " << errorCode;
+    return errorCode;
+  }
+  auto &newScaleNode = *(scaleIter - 1);
+  graph->allTensors.emplace_back(std::move(newScaleWeightTensor));
+  auto weightTensorIdx = graph->allTensors.size() - 1;
+  graph->allTensors.emplace_back(std::move(newScaleBiasTensor));
+  auto biasTensorIdx = graph->allTensors.size() - 1;
+  newScaleNode->inputIndex.push_back(weightTensorIdx);
+  newScaleNode->inputIndex.push_back(biasTensorIdx);
+  // delete bn node
+  auto status = IsolateOneWayNode(graph, bnPath->nodeIdx + 1, true);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "IsolateOneWayNode " << bnNode->name.c_str() << " failed, error: " << status;
+    return status;
+  }
+  return RET_OK;
+}
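+// GenNewScaleTensor packs the folded transScale/transBias buffers into two new
+// value tensors shaped like the BN mean (one float per channel); they become
+// the weight and bias inputs of the inserted Scale node.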
MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + + newScaleBiasTensor = std::unique_ptr(new(std::nothrow) TensorT); + if (newScaleBiasTensor == nullptr) { + MS_LOG(ERROR) << "new weightTensor failed"; + return RET_ERROR; + } + newScaleBiasTensor->dataType = bnMeanTensor->dataType; + newScaleBiasTensor->format = bnMeanTensor->format; + + newScaleBiasTensor->refCount = schema::NodeType_ValueNode; + newScaleBiasTensor->dims = bnMeanTensor->dims; + weightShapeSize = GetShapeSize(*bnMeanTensor); + newScaleBiasTensor->data.resize(weightShapeSize * sizeof(float)); + ret = memcpy_s(newScaleBiasTensor->data.data(), weightShapeSize * sizeof(float), transBias, + weightShapeSize * sizeof(float)); + if (ret != RET_OK) { + delete transBias; + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + return RET_OK; +} + +STATUS BatchNormConvertScalePass::FindNodes(MetaGraphT *graph, + const std::unordered_map> &matchedPath) { + MS_ASSERT(graph != nullptr); + auto inputPath = matchedPath.at(inputOpName); + auto bnPath = matchedPath.at(bnOpName); + MS_ASSERT(inputPath != nullptr); + MS_ASSERT(bnPath != nullptr); + if (inputPath->subGraphIdx != bnPath->subGraphIdx) { + MS_LOG(ERROR) << "matched nodes should from same subGraph"; + return RET_ERROR; + } + MS_ASSERT(graph->nodes.size() > inputPath->nodeIdx); + MS_ASSERT(graph->nodes.size() > bnPath->nodeIdx); + inputNode = graph->nodes.at(inputPath->nodeIdx).get(); + bnNode = graph->nodes.at(bnPath->nodeIdx).get(); + MS_ASSERT(inputNode != nullptr); + MS_ASSERT(bnNode != nullptr); + return RET_OK; +} +STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::shared_ptr &bnPath) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(bnPath != nullptr); + + BNWeightTensors bnWeightTensors; + + auto status = GetBnWeightTensors(graph, bnPath, &bnWeightTensors); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetBnWeightTensors error"; + return status; + } + auto *meanTensor = bnWeightTensors.meanTensor; + auto *varianceTensor = bnWeightTensors.varianceTensor; + auto *scaleTensor = bnWeightTensors.scaleTensor; + auto *biasTensor = bnWeightTensors.biasTensor; + + auto *meanData = reinterpret_cast(meanTensor->data.data()); + auto *varianceData = reinterpret_cast(varianceTensor->data.data()); + + eps = EPS_DEFAULT_FLOAT; + status = GetBnEpsilon(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetBnEpsilon failed"; + return status; + } + this->transScale = new(std::nothrow) float[bnChannel]; + this->transBias = new(std::nothrow) float[bnChannel]; + // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps) + if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) { + MS_LOG(ERROR) << "memcpy_s transScale error"; + return RET_ERROR; + } + // 1/sqrt(variance + eps) + for (uint32_t i = 0; i < bnChannel; i++) { + float tmp = transScale[i] + eps; + tmp = pow(tmp, POW_NUM); + transScale[i] = 1 / tmp; + } + + if (scaleTensor != nullptr) { + auto *scaleData = reinterpret_cast(scaleTensor->data.data()); + // scale/sqrt(variance + eps) + for (uint32_t i = 0; i < bnChannel; i++) { + transScale[i] *= scaleData[i]; + } + } + + // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps) + // -mean/sqrt(variance + eps) + for (uint32_t i = 0; i < bnChannel; i++) { + transBias[i] = -meanData[i] * transScale[i]; + } + + if (biasTensor != nullptr) { + auto *biasData = reinterpret_cast(biasTensor->data.data()); + // -scale*mean/sqrt(variance + 
+STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) {
+  MS_ASSERT(graph != nullptr);
+  MS_ASSERT(bnPath != nullptr);
+
+  BNWeightTensors bnWeightTensors;
+
+  auto status = GetBnWeightTensors(graph, bnPath, &bnWeightTensors);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "GetBnWeightTensors error";
+    return status;
+  }
+  auto *meanTensor = bnWeightTensors.meanTensor;
+  auto *varianceTensor = bnWeightTensors.varianceTensor;
+  auto *scaleTensor = bnWeightTensors.scaleTensor;
+  auto *biasTensor = bnWeightTensors.biasTensor;
+
+  auto *meanData = reinterpret_cast<float *>(meanTensor->data.data());
+  auto *varianceData = reinterpret_cast<float *>(varianceTensor->data.data());
+
+  eps = EPS_DEFAULT_FLOAT;
+  status = GetBnEpsilon(graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "GetBnEpsilon failed";
+    return status;
+  }
+  this->transScale = new (std::nothrow) float[bnChannel];
+  this->transBias = new (std::nothrow) float[bnChannel];
+  if (this->transScale == nullptr || this->transBias == nullptr) {
+    MS_LOG(ERROR) << "new transScale or transBias failed";
+    return RET_ERROR;
+  }
+  // cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
+  if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) {
+    MS_LOG(ERROR) << "memcpy_s transScale error";
+    return RET_ERROR;
+  }
+  // 1/sqrt(variance + eps)
+  for (uint32_t i = 0; i < bnChannel; i++) {
+    float tmp = transScale[i] + eps;
+    tmp = std::pow(tmp, POW_NUM);
+    transScale[i] = 1 / tmp;
+  }
+
+  if (scaleTensor != nullptr) {
+    auto *scaleData = reinterpret_cast<float *>(scaleTensor->data.data());
+    // scale/sqrt(variance + eps)
+    for (uint32_t i = 0; i < bnChannel; i++) {
+      transScale[i] *= scaleData[i];
+    }
+  }
+
+  // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps)
+  // -mean/sqrt(variance + eps)
+  for (uint32_t i = 0; i < bnChannel; i++) {
+    transBias[i] = -meanData[i] * transScale[i];
+  }
+
+  if (biasTensor != nullptr) {
+    auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
+    // -scale*mean/sqrt(variance + eps) + bias
+    for (uint32_t i = 0; i < bnChannel; i++) {
+      transBias[i] += biasData[i];
+    }
+  }
+
+  return RET_OK;
+}
+
+// BatchNorm weight tensor layout:
+// caffe
+//   estimated_mean     -- 0
+//   estimated_variance -- 1
+// tensorflow
+//   scale              -- 0
+//   bias               -- 1
+//   estimated_mean     -- 2
+//   estimated_variance -- 3
+STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath,
+                                                     BNWeightTensors *bnWeightTensors) {
+  if (graph == nullptr || bnPath == nullptr) {
+    MS_LOG(ERROR) << "null pointer dereferencing.";
+    return RET_NULL_PTR;
+  }
+  MS_ASSERT(graph->allTensors.size() > bnNode->inputIndex.at(1));
+  auto bnWeightTensorIdxes = bnNode->inputIndex;
+  bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin());
+  if (bnWeightTensorIdxes.size() == CAFFE_BATCHNORM_OP_WEIGHT_NUM) {
+    bnWeightTensors->meanTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_MEAN_INDEX]).get();
+    bnWeightTensors->varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_VARIANCE_INDEX]).get();
+  } else if (bnWeightTensorIdxes.size() == TF_BATCHNORM_OP_WEIGHT_NUM) {
+    bnWeightTensors->scaleTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_SCALE_INDEX]).get();
+    bnWeightTensors->biasTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_BIAS_INDEX]).get();
+    bnWeightTensors->meanTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_MEAN_INDEX]).get();
+    bnWeightTensors->varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_VARIANCE_INDEX]).get();
+  } else {
+    MS_LOG(ERROR) << "BatchNorm should have 2 or 4 weight tensors, current number of weight tensors: "
+                  << bnWeightTensorIdxes.size();
+    return RET_ERROR;
+  }
+
+  if (bnWeightTensors->meanTensor == nullptr) {
+    MS_LOG(ERROR) << "BatchNorm's mean tensor is nullptr";
+    return RET_ERROR;
+  }
+
+  if (bnWeightTensors->varianceTensor == nullptr) {
+    MS_LOG(ERROR) << "BatchNorm's variance tensor is nullptr";
+    return RET_ERROR;
+  }
+  bnChannel = bnWeightTensors->meanTensor->data.size() * sizeof(uint8_t) / sizeof(float);
+  if (bnChannel == 0) {
+    MS_LOG(ERROR) << "BatchNorm's channel is 0";
+    return RET_ERROR;
+  }
+  bnMeanTensor = bnWeightTensors->meanTensor;
+  if (bnChannel != bnWeightTensors->varianceTensor->data.size() * sizeof(uint8_t) / sizeof(float)) {
+    MS_LOG(ERROR) << "conv kernel num expected to be equal to variance size";
+    return RET_ERROR;
+  }
+
+  if (bnWeightTensors->scaleTensor != nullptr) {
+    if (bnChannel != bnWeightTensors->scaleTensor->data.size() * sizeof(uint8_t) / sizeof(float)) {
+      MS_LOG(ERROR) << "conv kernel num expected to be equal to scale size";
+      return RET_ERROR;
+    }
+  }
+
+  if (bnWeightTensors->biasTensor != nullptr) {
+    if (bnChannel != bnWeightTensors->biasTensor->data.size() * sizeof(uint8_t) / sizeof(float)) {
+      MS_LOG(ERROR) << "conv kernel num expected to be equal to bias size";
+      return RET_ERROR;
+    }
+  }
+  return RET_OK;
+}
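+// GetBnEpsilon reads epsilon from the matched BN primitive; values below EPS
+// are treated as unset and replaced with the default 1e-5.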
+STATUS BatchNormConvertScalePass::GetBnEpsilon(MetaGraphT *graph) {
+  if (graph == nullptr) {
+    MS_LOG(ERROR) << "null pointer dereferencing.";
+    return RET_NULL_PTR;
+  }
+  if (bnNode == nullptr) {
+    MS_LOG(ERROR) << "null pointer dereferencing.";
+    return RET_NULL_PTR;
+  }
+  if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) {
+    eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon;
+  } else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) {
+    eps = bnNode->primitive->value.AsBatchNorm()->epsilon;
+  } else {
+    MS_LOG(ERROR) << "match pattern has error, not a BatchNorm node";
+    return RET_ERROR;
+  }
+
+  if (eps < EPS) {
+    eps = EPS_DEFAULT_FLOAT;
+  }
+  return RET_OK;
+}
+
+BatchNormConvertScalePass::~BatchNormConvertScalePass() {
+  if (this->transScale != nullptr) {
+    delete[] this->transScale;
+    this->transScale = nullptr;
+  }
+  if (this->transBias != nullptr) {
+    delete[] this->transBias;
+    this->transBias = nullptr;
+  }
+}
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h
new file mode 100644
index 000000000..06a683370
--- /dev/null
+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_convert_scale_pass.h
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
+#define MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
+#include "tools/common/graph_util.h"
+
+namespace mindspore {
+namespace lite {
+struct BNWeightTensors {
+  TensorT *meanTensor = nullptr;
+  TensorT *varianceTensor = nullptr;
+  TensorT *scaleTensor = nullptr;
+  TensorT *biasTensor = nullptr;
+};
+class BatchNormConvertScalePass : public FusionPass {
+ public:
+  BatchNormConvertScalePass() = default;
+
+  ~BatchNormConvertScalePass() override;
+
+  STATUS DefinePattern() override;
+
+  STATUS DoFusion(MetaGraphT *graph, const std::string &patternName,
+                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
+
+  STATUS Run(MetaGraphT *graph) override;
+
+ protected:
+  STATUS GetTransParam(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath);
+
+  // Get and check the BN node's weight tensors
+  STATUS GetBnWeightTensors(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath, BNWeightTensors *bnWeightTensors);
+
+  STATUS GetBnEpsilon(MetaGraphT *graph);
+
+  STATUS FindNodes(MetaGraphT *graph, const std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath);
+
+  STATUS GenNewScaleTensor(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath);
+
+  STATUS ConvertBNToScale(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath);
+
+  CNodeT *inputNode = nullptr;
+  CNodeT *bnNode = nullptr;
+
+  std::string inputOpName = "Input";
+  std::string bnOpName = "BatchNorm";
+  std::string bnPatternName = "BnToScaleFusion";
+  uint32_t bnChannel = 0;
+  float eps = 0;
+  TensorT *bnMeanTensor = nullptr;
+  float *transScale = nullptr;
+  float *transBias = nullptr;
+  std::unique_ptr<TensorT> newScaleWeightTensor = nullptr;
+  std::unique_ptr<TensorT> newScaleBiasTensor = nullptr;
+
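+  // Copier passed to InsertNode: clones a Scale CNodeT (name, quantType and
+  // axis) so every inserted node owns an independent primitive.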
+  OpDefCopyer ScaleOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
+    std::unique_ptr<CNodeT> newOpDef(new (std::nothrow) CNodeT);
+    if (newOpDef == nullptr) {
+      MS_LOG(ERROR) << "new OpDefT failed";
+      return nullptr;
+    }
+    newOpDef->name = inOpDef->name;
+    newOpDef->quantType = inOpDef->quantType;
+    newOpDef->primitive = std::make_unique<PrimitiveT>();
+    newOpDef->primitive->value.type = schema::PrimitiveType_Scale;
+    auto scaleParam = new (std::nothrow) ScaleT;
+    if (scaleParam == nullptr) {
+      MS_LOG(ERROR) << "new scaleParam failed";
+      return nullptr;
+    }
+    auto inParam = inOpDef->primitive->value.AsScale();
+    MS_ASSERT(inParam != nullptr);
+    scaleParam->axis = inParam->axis;
+    newOpDef->primitive->value.value = scaleParam;
+    return std::move(newOpDef);
+  };
+};
+}  // namespace lite
+}  // namespace mindspore
+#endif  // MINDSPORE_PREDICT_BATCHNORM_CONVERT_SCALE_PASS_H
-- 
GitLab