提交 66271e41 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5333 fix memory leak

Merge pull request !5333 from hangq/primitive
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* <p>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -52,7 +52,7 @@ public class MSTensor {
this.setDataType(this.tensorPtr, dataType);
}
public byte[] getBtyeData() {
public byte[] getByteData() {
return this.getByteData(this.tensorPtr);
}
......
......@@ -18,6 +18,7 @@
#include <jni.h>
#include "common/ms_log.h"
#include "include/context.h"
#include "include/thread_pool_config.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_createContext(JNIEnv *env, jobject thiz,
jint device_type,
......@@ -44,13 +45,13 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_creat
}
switch (cpu_bind_mode) {
case -1:
context->cpu_bind_mode_ = mindspore::lite::MID_CPU;
context->cpu_bind_mode_ = MID_CPU;
break;
case 0:
context->cpu_bind_mode_ = mindspore::lite::NO_BIND;
context->cpu_bind_mode_ = NO_BIND;
break;
case 1:
context->cpu_bind_mode_ = mindspore::lite::HIGHER_CPU;
context->cpu_bind_mode_ = HIGHER_CPU;
break;
default:
MS_LOGE("Invalid cpu_bind_mode : %d", cpu_bind_mode);
......
......@@ -118,9 +118,9 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByte
return env->NewByteArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewByteArray(local_data_size);
env->SetByteArrayRegion(ret, 0, local_data_size, local_data);
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewByteArray(local_element_num);
env->SetByteArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
......@@ -144,9 +144,9 @@ extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLong
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewLongArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewLongArray(local_data_size);
env->SetLongArrayRegion(ret, 0, local_data_size, local_data);
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewLongArray(local_element_num);
env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
......@@ -170,9 +170,9 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntDa
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewIntArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewIntArray(local_data_size);
env->SetIntArrayRegion(ret, 0, local_data_size, local_data);
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewIntArray(local_element_num);
env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
......@@ -196,9 +196,9 @@ extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFlo
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewFloatArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewFloatArray(local_data_size);
env->SetFloatArrayRegion(ret, 0, local_data_size, local_data);
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewFloatArray(local_element_num);
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
......
......@@ -100,8 +100,8 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<tensor::Tensor *
const std::vector<tensor::Tensor *> &out_tensors,
const PrimitiveC *primitive, const Context *ctx,
const kernel::KernelKey &key) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(ctx);
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != ctx);
auto parameter = kernel::PopulateParameter(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
......
......@@ -33,9 +33,9 @@
namespace mindspore {
namespace lite {
int LiteSession::ConvertTensors(const lite::Model *model) {
MS_EXCEPTION_IF_NULL(model);
MS_ASSERT(nullptr != model);
auto meta_graph = model->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph);
MS_ASSERT(nullptr != meta_graph);
uint32_t tensorCount = meta_graph->allTensors()->size();
for (uint32_t i = 0; i < tensorCount; i++) {
auto *srcTensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
......@@ -246,7 +246,7 @@ int LiteSession::CompileGraph(Model *model) {
std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const { return this->input_vec_; }
int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) {
MS_EXCEPTION_IF_NULL(this->context_);
MS_ASSERT(this->context_);
if (before == nullptr && after == nullptr) {
return executor->Run(this->inputs_, this->outputs_, this->kernels_, this->context_->allocator.get());
} else {
......@@ -255,7 +255,7 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session::
}
int LiteSession::Init(Context *context) {
MS_EXCEPTION_IF_NULL(context);
MS_ASSERT(nullptr != context);
this->context_ = new (std::nothrow) Context(context->thread_num_, context->allocator, context->device_ctx_);
if (this->context_ == nullptr) {
MS_LOG(ERROR) << "new context failed";
......@@ -276,7 +276,7 @@ int LiteSession::Init(Context *context) {
}
#endif
executor = new Executor();
MS_EXCEPTION_IF_NULL(executor);
MS_ASSERT(nullptr != executor);
return RET_OK;
}
......
......@@ -101,7 +101,7 @@ int ModelImpl::BuildOps() {
MS_LOG(ERROR) << "mete_graph is nullptr";
return -1;
}
MS_EXCEPTION_IF_NULL(meta_graph_->nodes());
MS_ASSERT(nullptr != meta_graph_->nodes());
for (size_t i = 0; i < meta_graph_->nodes()->size(); i++) {
auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i);
auto name = cNode->name()->str();
......@@ -129,17 +129,17 @@ Model *Model::Import(const char *model_buf, size_t size) {
Model::~Model() { delete (this->model_impl_); }
mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const {
MS_EXCEPTION_IF_NULL(model_impl_);
MS_ASSERT(nullptr != model_impl_);
return const_cast<PrimitiveC *>(model_impl_->GetOp(name));
}
void Model::FreeMetaGraph() {
MS_EXCEPTION_IF_NULL(model_impl_);
MS_ASSERT(nullptr != model_impl_);
return model_impl_->FreeMetaGraph();
}
const schema::MetaGraph *Model::GetMetaGraph() const {
MS_EXCEPTION_IF_NULL(model_impl_);
MS_ASSERT(nullptr != model_impl_);
return model_impl_->meta_graph();
}
......
......@@ -31,8 +31,8 @@ class ParamValueLite : public Value {
ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {}
virtual ~ParamValueLite() {
if (tensor_addr_ != nullptr) {
auto tensor_mem = reinterpret_cast<char*>(tensor_addr_);
delete tensor_mem;
auto tensor_mem = reinterpret_cast<char *>(tensor_addr_);
delete[](tensor_mem);
tensor_addr_ = nullptr;
tensor_size_ = 0;
}
......
......@@ -277,7 +277,7 @@ kernel::LiteKernel *CpuArithmeticGradFp32KernelCreator(const std::vector<lite::t
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_EXCEPTION_IF_NULL(opParameter);
MS_ASSERT(nullptr != opParameter);
if (opParameter == nullptr) {
return nullptr;
}
......
......@@ -206,7 +206,7 @@ int SubGraphOpenCLKernel::MallocTensorWithReuse() {
output->set_allocator(allocator_);
}
for (auto input_kernel : kernel->in_kernels()) {
MS_EXCEPTION_IF_NULL(input_kernel);
MS_ASSERT(nullptr != input_kernel);
auto ret = input_kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
......@@ -214,21 +214,21 @@ int SubGraphOpenCLKernel::MallocTensorWithReuse() {
}
}
for (auto kernel : out_kernels_) {
MS_EXCEPTION_IF_NULL(kernel);
MS_ASSERT(nullptr != kernel);
auto ret = kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
}
}
for (auto kernel : in_convert_ops_) {
MS_EXCEPTION_IF_NULL(kernel);
MS_ASSERT(nullptr != kernel);
auto ret = kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
}
}
for (auto kernel : out_convert_ops_) {
MS_EXCEPTION_IF_NULL(kernel);
MS_ASSERT(nullptr != kernel);
auto ret = kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
......
......@@ -65,7 +65,7 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso
}
}
for (auto input_kernel : kernel->in_kernels()) {
MS_EXCEPTION_IF_NULL(input_kernel);
MS_ASSERT(nullptr != input_kernel);
ret = input_kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed";
......
......@@ -77,10 +77,10 @@ int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
}
int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *> *tensors) {
MS_EXCEPTION_IF_NULL(model);
MS_EXCEPTION_IF_NULL(tensors);
MS_ASSERT(nullptr != model);
MS_ASSERT(nullptr != tensors);
auto meta_graph = model->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph);
MS_ASSERT(nullptr != meta_graph);
bool infer_shape_interrupt = false;
uint32_t kernelCount = meta_graph->nodes()->size();
for (uint32_t i = 0; i < kernelCount; i++) {
......@@ -121,10 +121,10 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tensor *> *tensors,
std::vector<kernel::LiteKernel *> *kernels) {
MS_EXCEPTION_IF_NULL(model);
MS_EXCEPTION_IF_NULL(tensors);
MS_ASSERT(nullptr != model);
MS_ASSERT(nullptr != tensors);
auto meta_graph = model->GetMetaGraph();
MS_EXCEPTION_IF_NULL(meta_graph);
MS_ASSERT(nullptr != meta_graph);
uint32_t kernelCount = meta_graph->nodes()->size();
auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
for (uint32_t i = 0; i < kernelCount; i++) {
......
......@@ -23,7 +23,7 @@ namespace mindspore {
schema::MetaGraphT *TestTfliteParser::LoadAndConvert(const string &model_path, const string &weight_path) {
lite::TfliteModelParser parser;
meta_graph = parser.Parse(model_path, weight_path);
meta_graph = parser.ParseToFb(model_path, weight_path);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;
......
......@@ -107,7 +107,7 @@ int Benchmark::ReadInputFile() {
}
auto inputData = cur_tensor->MutableData();
memcpy(inputData, binBuf, tensorDataSize);
delete binBuf;
delete[](binBuf);
}
}
return RET_OK;
......@@ -455,6 +455,12 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
}
if (!_flags->calibDataPath.empty()) {
status = MarkAccuracy();
for (auto &data : calibData) {
data.second->shape.clear();
data.second->data.clear();
delete data.second;
}
calibData.clear();
if (status != 0) {
MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
std::cout << "Run MarkAccuracy error: " << status << std::endl;
......@@ -472,16 +478,6 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
return status;
}
}
if (cleanData) {
for (auto &data : calibData) {
data.second->shape.clear();
data.second->data.clear();
delete data.second;
}
calibData.clear();
}
delete (session);
delete (model);
return RET_OK;
......
......@@ -138,7 +138,6 @@ class MS_API Benchmark {
std::vector<mindspore::tensor::MSTensor *> msInputs;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> msOutputs;
std::unordered_map<std::string, CheckTensor *> calibData;
bool cleanData = true;
};
int MS_API RunBenchmark(int argc, const char **argv);
......
......@@ -35,7 +35,7 @@ OpDefCopyer GetSimpleOpCopyer() {
newCNode->quantType = inCNode->quantType;
newCNode->primitive = std::make_unique<schema::PrimitiveT>();
newCNode->primitive->value.type = inCNode->primitive->value.type;
return std::move(newCNode);
return newCNode;
};
}
......@@ -96,7 +96,7 @@ std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size
preNodeIdx.emplace_back(i);
}
}
return std::move(preNodeIdx);
return preNodeIdx;
}
std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
......@@ -111,7 +111,7 @@ std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const siz
postNodeIdx.emplace_back(i);
}
}
return std::move(postNodeIdx);
return postNodeIdx;
}
STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
......
......@@ -89,7 +89,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
MS_LOG(ERROR) << "Dim size invalid";
return RET_ERROR;
}
std::unique_ptr<T> buf(new (std::nothrow) T[count]);
std::unique_ptr<T[]> buf(new (std::nothrow) T[count]);
if (buf == nullptr) {
MS_LOG(ERROR) << "new buf failed";
return RET_ERROR;
......
......@@ -24,7 +24,7 @@ std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT>
MS_ASSERT(tensor != nullptr);
auto &quantParams = tensor->quantParams;
if (!quantParams.empty()) {
return std::move(CopyQuantParamT(quantParams.front()));
return CopyQuantParamT(quantParams.front());
} else {
return nullptr;
}
......@@ -39,7 +39,7 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem
dstQuantParam->max = srcQuantParam->max;
dstQuantParam->narrowRange = srcQuantParam->narrowRange;
dstQuantParam->numBits = srcQuantParam->numBits;
return std::move(dstQuantParam);
return dstQuantParam;
}
size_t GetElementSize(const TensorT &tensor) { return GetElementSize(TypeId(tensor.dataType)); }
......@@ -87,7 +87,7 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &oldTenso
if (!oldTensor->quantParams.empty()) {
newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor)));
}
return std::move(newTensor);
return newTensor;
}
size_t GetRefCount(MetaGraphT *graphT, uint32_t tensorIdx) {
......
......@@ -24,6 +24,8 @@
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
using std::string;
namespace mindspore {
......@@ -32,10 +34,9 @@ AnfTransform::AnfTransform() = default;
AnfTransform::~AnfTransform() = default;
void AnfTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) {
// return old_graph;
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(nullptr != old_graph);
// fusion const_fold
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
......@@ -54,6 +55,31 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph) {
pm->AddPass(std::make_shared<opt::ConstFoldPass>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(old_graph);
// quant
if (config != nullptr && config->quantType == schema::QuantType_PostTraining) {
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
return nullptr;
}
}
if (mQuantizer != nullptr) {
mQuantizer->flags = *config;
auto status = mQuantizer->DoQuantize(new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status;
return nullptr;
}
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
status = quant_cast.Run(new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
return nullptr;
}
}
return new_graph;
}
} // namespace lite
......
......@@ -17,11 +17,12 @@
#ifndef MS_ANF_TRANSFORM_H
#define MS_ANF_TRANSFORM_H
#include <memory>
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
#include "ir/anf.h"
#include "tools/converter/quantizer/quantizer.h"
namespace mindspore {
namespace lite {
......@@ -29,15 +30,12 @@ class AnfTransform {
public:
AnfTransform();
virtual ~AnfTransform();
FuncGraphPtr Transform(const FuncGraphPtr &old_graph);
void SetGraphDef(schema::MetaGraphT *dstDef);
inline schema::MetaGraphT *GetOutput() { return graphDefT; }
FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
protected:
schema::MetaGraphT *graphDefT = nullptr;
private:
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
};
} // namespace lite
} // namespace mindspore
#endif
......@@ -70,41 +70,23 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile;
const std::string weightFile = flag->weightFile;
auto meta_graph = modelParser->Parse(modelFile, weightFile, flag->quantType);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;
}
graph = ModelParser::Fb2Anf(meta_graph);
graph = modelParser->Parse(modelFile, weightFile, flag->quantType);
}
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return nullptr;
}
graph = anfTransform->Transform(graph);
CreateQuantizer(graph, flag);
if (mQuantizer != nullptr) {
mQuantizer->flags = *flag;
auto status = mQuantizer->DoQuantize(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status;
return nullptr;
}
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
status = quant_cast.Run(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
return nullptr;
}
graph = anfTransform->Transform(graph, flag);
if (graph == nullptr) {
MS_LOG(ERROR) << "Transform anf graph return nullptr";
return nullptr;
}
// anf -- fb
auto meta_graph = Export(graph);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph return nullptr";
MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr;
}
......@@ -113,20 +95,13 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
transform->CreateQuantizer(flag);
auto status = transform->Transform(*flag);
if (status != 0) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
MS_LOG(ERROR) << "Transform meta graph failed " << status;
return nullptr;
}
return meta_graph;
}
void Converter::CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags) {
auto type = flags->quantType;
if (type == mindspore::schema::QuantType_PostTraining) {
MS_LOG(INFO) << "create post training quantizer.";
mQuantizer.reset(new quant::PostTrainingQuantizer(func_graph, flags->configFile, 8));
}
}
int RunConverter(int argc, const char **argv) {
std::unique_ptr<converter::Flags> flags(new (std::nothrow) converter::Flags);
if (flags == nullptr) {
......
......@@ -25,7 +25,6 @@
#include "tools/anf_importer/anf_importer.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/anf_transform.h"
#include "tools/converter/quantizer/quantizer.h"
namespace mindspore {
namespace lite {
......@@ -34,15 +33,12 @@ class Converter {
Converter();
virtual ~Converter();
virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags);
void CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags);
void FreeFuncGraph(const FuncGraphPtr &func_graph);
protected:
ModelParser *modelParser = nullptr;
AnfImporter *modelImporter = nullptr;
GraphDefTransform *transform = nullptr;
AnfTransform *anfTransform = nullptr;
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
};
int RunConverter(int argc, const char **argv);
......
......@@ -15,15 +15,12 @@
*/
#include "tools/converter/graphdef_transform.h"
#include <iostream>
#include <memory>
#include <string>
#include "schema/model_generated.h"
#include "utils/log_adapter.h"
#include "src/common/op_utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
......@@ -37,7 +34,6 @@
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/quantizer/aware_quantizer.h"
#include "tools/converter/converter.h"
using std::string;
namespace mindspore::lite {
......@@ -72,7 +68,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
weightHardCodePass->SetFmkType(ctx.fmk);
weightFormatPass->SetQuantType(ctx.quantType);
weightFormatPass->SetFmkType(ctx.fmk);
// weightFormatPass->SetDstFormat(Format_KHWC);
weightFormatOptimizer.AddPass(weightHardCodePass);
weightFormatOptimizer.AddPass(weightFormatPass);
status = weightFormatOptimizer.Run(graphDefT);
......@@ -153,9 +148,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
formatTransOptimizer.AddPass(new EltwiseFormatTransPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
// if (ctx.quantType == QuantType_AwareTraining) {
// formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass());
// }
status = formatTransOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed";
......
......@@ -74,7 +74,7 @@ class MatMulBiasAddFusionPass : public FusionPass {
std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(),
[](const int32_t ele) { return ele; });
newOpDef->primitive->value.value = transposeParam;
return std::move(newOpDef);
return newOpDef;
};
};
} // namespace lite
......
......@@ -72,7 +72,7 @@ class DTypeTransPass : public GraphPass {
QuantDTypeCastParam->srcT = oldQuantDTypeCastParam->srcT;
QuantDTypeCastParam->dstT = oldQuantDTypeCastParam->dstT;
newCNode->primitive->value.value = QuantDTypeCastParam;
return std::move(newCNode);
return newCNode;
};
};
} // namespace lite
......
......@@ -32,16 +32,16 @@ class ModelParser {
virtual ~ModelParser() {}
virtual FuncGraphPtr ParseToAnf(const std::string &modelFile, const std::string &weightFile) {
auto *meta_graph = Parse(modelFile, weightFile);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;
}
return Fb2Anf(Parse(modelFile, weightFile));
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) {
auto *meta_graph = ParseToFb(modelFile, weightFile, quantType);
auto func_graph = this->Fb2Anf(meta_graph);
delete(meta_graph);
return func_graph;
}
virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
virtual schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
public:
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {
......
......@@ -31,7 +31,7 @@ CaffeModelParser::~CaffeModelParser() {}
const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};
schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile,
schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile,
const std::string &weightFile,
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) {
......@@ -49,7 +49,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile,
return nullptr;
}
std::unique_ptr<schema::MetaGraphT> subGraphDef = std::make_unique<schema::MetaGraphT>();
auto metaGraph = std::make_unique<schema::MetaGraphT>();
TensorCache tensorCache;
caffe::NetParameter proto;
......@@ -57,7 +57,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile,
MS_LOG(ERROR) << "Read prototxt file failed, model path: " << modelFile;
return nullptr;
}
subGraphDef->name = proto.name();
metaGraph->name = proto.name();
caffe::NetParameter weight;
if (ReadProtoFromBinaryFile((const char *)weightFile.c_str(), &weight) != RET_OK) {
......@@ -71,22 +71,22 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile,
return nullptr;
}
status = ParseLayer(proto, weight, &tensorCache, subGraphDef.get());
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "ParseLayer failed " << status;
return nullptr;
}
status = SetGraphTensorIndex(proto, &tensorCache, subGraphDef.get());
status = SetGraphTensorIndex(proto, &tensorCache, metaGraph.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "Set inputTensor index and outputTensor index for graph failed!";
return nullptr;
}
subGraphDef->name = GetModelName(modelFile);
metaGraph->name = GetModelName(modelFile);
SetAllTensors(tensorCache, subGraphDef.get());
SetAllTensors(tensorCache, metaGraph.get());
return subGraphDef.release();
return metaGraph.release();
}
STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer,
......
......@@ -33,7 +33,7 @@ class CaffeModelParser : public ModelParser {
virtual ~CaffeModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
......
......@@ -507,14 +507,14 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
}
}
MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile,
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile,
const std::string &weightFile,
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
return nullptr;
}
std::unique_ptr<schema::MetaGraphT> dst_graph = std::make_unique<schema::MetaGraphT>();
auto dst_graph = std::make_unique<schema::MetaGraphT>();
onnx::ModelProto onnx_model;
if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) {
MS_LOG(ERROR) << "read onnx model fail";
......
......@@ -40,7 +40,7 @@ class OnnxModelParser : public ModelParser {
virtual ~OnnxModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
......
......@@ -44,7 +44,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Add") == 0) {
MS_LOG(DEBUG) << "parse TfliteAddParser";
std::unique_ptr<schema::AddT> attr = std::make_unique<schema::AddT>();
auto attr = std::make_unique<schema::AddT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -59,7 +59,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Sub") == 0) {
MS_LOG(DEBUG) << "parse TfliteSubParser";
std::unique_ptr<schema::SubT> attr = std::make_unique<schema::SubT>();
auto attr = std::make_unique<schema::SubT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -74,7 +74,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Mul") == 0) {
MS_LOG(DEBUG) << "parse TfliteMulParser";
std::unique_ptr<schema::MulT> attr = std::make_unique<schema::MulT>();
auto attr = std::make_unique<schema::MulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -89,7 +89,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Div") == 0) {
MS_LOG(DEBUG) << "parse TfliteDivParser";
std::unique_ptr<schema::DivT> attr = std::make_unique<schema::DivT>();
auto attr = std::make_unique<schema::DivT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -113,7 +113,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "FloorMod") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorModParser";
std::unique_ptr<schema::FloorModT> attr = std::make_unique<schema::FloorModT>();
auto attr = std::make_unique<schema::FloorModT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -131,7 +131,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "SquaredDifference") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquaredDifferenceParser";
std::unique_ptr<schema::SquaredDifferenceT> attr = std::make_unique<schema::SquaredDifferenceT>();
auto attr = std::make_unique<schema::SquaredDifferenceT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -140,7 +140,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Pow") == 0) {
MS_LOG(DEBUG) << "parse TflitePowParser";
std::unique_ptr<schema::PowerT> attr = std::make_unique<schema::PowerT>();
auto attr = std::make_unique<schema::PowerT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -152,7 +152,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Maximum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMaximumParser";
std::unique_ptr<schema::MaximumT> attr = std::make_unique<schema::MaximumT>();
auto attr = std::make_unique<schema::MaximumT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -161,7 +161,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Minimum") == 0) {
MS_LOG(DEBUG) << "parse TfliteMinimumParser";
std::unique_ptr<schema::MinimumT> attr = std::make_unique<schema::MinimumT>();
auto attr = std::make_unique<schema::MinimumT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -202,7 +202,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
const char *node_name = node_name_str.data()->c_str();
if (std::strcmp(node_name, "Abs") == 0) {
MS_LOG(DEBUG) << "parse TfliteAbsParser";
std::unique_ptr<schema::AbsT> attr = std::make_unique<schema::AbsT>();
auto attr = std::make_unique<schema::AbsT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -211,7 +211,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Exp") == 0) {
MS_LOG(DEBUG) << "parse TfliteExpParser";
std::unique_ptr<schema::ExpT> attr = std::make_unique<schema::ExpT>();
auto attr = std::make_unique<schema::ExpT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -220,7 +220,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Sqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteSqrtParser";
std::unique_ptr<schema::SqrtT> attr = std::make_unique<schema::SqrtT>();
auto attr = std::make_unique<schema::SqrtT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -229,7 +229,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Rsqrt") == 0) {
MS_LOG(DEBUG) << "parse TfliteRsqrtParser";
std::unique_ptr<schema::RsqrtT> attr = std::make_unique<schema::RsqrtT>();
auto attr = std::make_unique<schema::RsqrtT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -238,7 +238,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Square") == 0) {
MS_LOG(DEBUG) << "parse TfliteSquareParser";
std::unique_ptr<schema::SquareT> attr = std::make_unique<schema::SquareT>();
auto attr = std::make_unique<schema::SquareT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -247,7 +247,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Sin") == 0) {
MS_LOG(DEBUG) << "parse TfliteSinParser";
std::unique_ptr<schema::SinT> attr = std::make_unique<schema::SinT>();
auto attr = std::make_unique<schema::SinT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -265,7 +265,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Log") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogParser";
std::unique_ptr<schema::LogT> attr = std::make_unique<schema::LogT>();
auto attr = std::make_unique<schema::LogT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -274,7 +274,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Round") == 0) {
MS_LOG(DEBUG) << "parse TfliteRoundParser";
std::unique_ptr<schema::RoundT> attr = std::make_unique<schema::RoundT>();
auto attr = std::make_unique<schema::RoundT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -283,7 +283,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Ceil") == 0) {
MS_LOG(DEBUG) << "parse TfliteCeilParser";
std::unique_ptr<schema::CeilT> attr = std::make_unique<schema::CeilT>();
auto attr = std::make_unique<schema::CeilT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......@@ -292,7 +292,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "flOOR") == 0) {
MS_LOG(DEBUG) << "parse TfliteFloorParser";
std::unique_ptr<schema::FloorT> attr = std::make_unique<schema::FloorT>();
auto attr = std::make_unique<schema::FloorT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
......
......@@ -28,26 +28,25 @@ namespace mindspore {
namespace lite {
TfliteModelParser::TfliteModelParser() = default;
TfliteModelParser::~TfliteModelParser() = default;
TfliteModelParser::~TfliteModelParser() { delete[](this->tfliteModelBuf); }
std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) {
size_t size;
auto buf = ReadFile(model_path, &size);
if (buf == nullptr) {
tfliteModelBuf = ReadFile(model_path, &size);
if (tfliteModelBuf == nullptr) {
MS_LOG(ERROR) << "the file buffer is nullptr";
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)buf, size);
flatbuffers::Verifier verify((const uint8_t *)tfliteModelBuf, size);
if (!tflite::VerifyModelBuffer(verify)) {
MS_LOG(ERROR) << "the buffer is invalid and fail to create graph";
return nullptr;
}
return tflite::UnPackModel(buf);
return tflite::UnPackModel(tfliteModelBuf);
}
STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const tflite::TensorT *tflite_tensor,
schema::TensorT *tensor) {
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
auto count = 1;
std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; });
auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType));
......@@ -95,8 +94,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const QuantType &quant_type,
schema::MetaGraphT *sub_graph) {
const QuantType &quant_type, schema::MetaGraphT *sub_graph) {
int idx = 0;
for (const auto &tflite_op : tflite_subgraph->operators) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
......@@ -107,7 +105,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
return RET_ERROR;
}
std::unique_ptr<schema::CNodeT> op = std::make_unique<schema::CNodeT>();
auto op = std::make_unique<schema::CNodeT>();
op->name = op_type + "-" + std::to_string(idx++);
op->quantType = quant_type;
MS_LOG(INFO) << "parse op: " << op->name.c_str();
......@@ -227,7 +225,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT>
return RET_OK;
}
STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph) {
STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) {
for (auto &op : sub_graph->nodes) {
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
auto attr = op->primitive->value.AsDepthwiseConv2D();
......@@ -301,15 +299,10 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT* sub_graph)
return RET_OK;
}
MetaGraphT *TfliteModelParser::Parse(const std::string &model_file,
const std::string &weight_file,
const QuantType &quant_type) {
std::unique_ptr<schema::MetaGraphT> sub_graph = std::make_unique<schema::MetaGraphT>();
sub_graph->name = "MS_model converted by TF-Lite";
quantType = quant_type;
schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
// load graph
std::unique_ptr<tflite::ModelT> tflite_model = ReadTfliteModel(model_file.c_str());
auto tflite_model = ReadTfliteModel(model_file.c_str());
if (tflite_model == nullptr) {
MS_LOG(ERROR) << "read tflite model failed";
return nullptr;
......@@ -321,31 +314,38 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file,
}
const auto &tflite_subgraph = tflite_model->subgraphs[0];
auto meta_graph = std::make_unique<schema::MetaGraphT>();
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "new meta graph failed";
return nullptr;
}
meta_graph->name = "MS_model converted by TF-Lite";
quantType = quant_type;
// convert op
if (ConvertOp(tflite_model, tflite_subgraph, quant_type, sub_graph.get()) != RET_OK) {
if (ConvertOp(tflite_model, tflite_subgraph, quant_type, meta_graph.get()) != RET_OK) {
MS_LOG(ERROR) << "parse op failed.";
return nullptr;
}
// convert tensor
if (ConvertTensor(tflite_subgraph, tflite_model->buffers, sub_graph.get()) != RET_OK) {
if (ConvertTensor(tflite_subgraph, tflite_model->buffers, meta_graph.get()) != RET_OK) {
MS_LOG(ERROR) << "convert tensor failed";
return nullptr;
}
// set graph input/output
if (GetGraphInfo(tflite_subgraph, sub_graph.get()) != RET_OK) {
if (GetGraphInfo(tflite_subgraph, meta_graph.get()) != RET_OK) {
MS_LOG(ERROR) << "convert tensors failed";
return nullptr;
}
// update for depthwiseConv
if (ConvertGroupDepthwiseOp(sub_graph.get()) != RET_OK) {
if (ConvertGroupDepthwiseOp(meta_graph.get()) != RET_OK) {
MS_LOG(ERROR) << "convert group depthwise conv failed";
return nullptr;
}
return sub_graph.release();
return meta_graph.release();
}
} // namespace lite
} // namespace mindspore
......@@ -41,7 +41,7 @@ class TfliteModelParser : public ModelParser {
~TfliteModelParser() override;
MetaGraphT *Parse(const std::string &model_file,
schema::MetaGraphT *ParseToFb(const std::string &model_file,
const std::string &weight_file,
const QuantType &quantType = QuantType_QUANT_NONE) override;
......@@ -78,6 +78,7 @@ class TfliteModelParser : public ModelParser {
std::map<std::string, schema::CNodeT *> opMap;
std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap;
QuantType quantType = QuantType_QUANT_NONE;
char *tfliteModelBuf = nullptr;
};
} // namespace lite
} // namespace mindspore
......
......@@ -41,200 +41,173 @@ using std::vector;
namespace mindspore {
namespace lite {
namespace quant {
struct DivergInfo {
std::vector<float> histogram;
CNodePtr cnode;
int bin_num;
float interval = 0;
float max;
float min;
float best_T = 0.0f;
size_t bit_num;
int quant_max = 255;
int quant_min = 0;
std::string method_x = kMethodKL;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) {
this->method_x = method_x;
this->cnode = cnode;
this->bin_num = bins;
this->bit_num = bits;
histogram.resize(bin_num);
max = -FLT_MAX;
min = FLT_MAX;
this->quant_max = quant_max;
this->quant_min = quant_min;
std::fill(histogram.begin(), histogram.end(), 1.0e-7);
STATUS DivergInfo::RecordMaxValue(const std::vector<float> &datas) {
for (float data : datas) {
max = std::max(data, max);
min = std::min(data, min);
}
return RET_OK;
}
STATUS RecordMaxValue(const std::vector<float> &datas) {
for (float data : datas) {
max = std::max(data, max);
min = std::min(data, min);
void DivergInfo::UpdateInterval() {
auto max_value = std::max(fabs(this->max), fabs(this->min));
this->interval = max_value / static_cast<float>(bin_num);
}
STATUS DivergInfo::UpdateHistogram(const std::vector<float> &data) {
for (auto value : data) {
if (value == 0) {
continue;
}
return RET_OK;
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
this->histogram[bin_index]++;
}
return RET_OK;
}
void UpdateInterval() {
auto max_value = std::max(fabs(this->max), fabs(this->min));
this->interval = max_value / static_cast<float>(bin_num);
void DivergInfo::DumpHistogram() {
MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram";
for (float item : this->histogram) {
std::cout << item << " ";
}
std::cout << std::endl;
}
STATUS UpdateHistogram(const std::vector<float> &data) {
for (auto value : data) {
if (value == 0) {
continue;
}
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
this->histogram[bin_index]++;
}
STATUS DivergInfo::ComputeThreshold() {
if (method_x == kMethodMaxMin) {
this->best_T = std::max(fabs(this->max), fabs(this->min));
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
return RET_OK;
}
void DumpHistogram() {
MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram";
for (float item : this->histogram) {
std::cout << item << " ";
}
std::cout << std::endl;
}
STATUS ComputeThreshold() {
if (method_x == kMethodMaxMin) {
this->best_T = std::max(fabs(this->max), fabs(this->min));
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
return RET_OK;
constexpr int quant_bint_nums = 128;
int threshold = quant_bint_nums;
float min_kl = FLT_MAX;
float after_threshold_sum = std::accumulate(this->histogram.begin() + quant_bint_nums, this->histogram.end(), 0.0f);
for (int i = quant_bint_nums; i < this->bin_num; ++i) {
std::vector<float> quantized_histogram(quant_bint_nums, 0);
std::vector<float> reference_histogram(this->histogram.begin(), this->histogram.begin() + i);
std::vector<float> expanded_histogram(i, 0);
reference_histogram[i - 1] += after_threshold_sum;
after_threshold_sum -= this->histogram[i];
const float bin_interval = static_cast<float>(i) / static_cast<float>(quant_bint_nums);
// merge i bins to target bins
for (int j = 0; j < quant_bint_nums; ++j) {
const float start = j * bin_interval;
const float end = start + bin_interval;
const int left_upper = static_cast<int>(std::ceil(start));
if (left_upper > start) {
const double left_scale = left_upper - start;
quantized_histogram[j] += left_scale * this->histogram[left_upper - 1];
}
const int right_lower = static_cast<int>(std::floor(end));
if (right_lower < end) {
const double right_scale = end - right_lower;
quantized_histogram[j] += right_scale * this->histogram[right_lower];
}
std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower,
[&quantized_histogram, j](float item) { quantized_histogram[j] += item; });
}
constexpr int quant_bint_nums = 128;
int threshold = quant_bint_nums;
float min_kl = FLT_MAX;
float after_threshold_sum = std::accumulate(this->histogram.begin() + quant_bint_nums, this->histogram.end(), 0.0f);
for (int i = quant_bint_nums; i < this->bin_num; ++i) {
std::vector<float> quantized_histogram(quant_bint_nums, 0);
std::vector<float> reference_histogram(this->histogram.begin(), this->histogram.begin() + i);
std::vector<float> expanded_histogram(i, 0);
reference_histogram[i - 1] += after_threshold_sum;
after_threshold_sum -= this->histogram[i];
const float bin_interval = static_cast<float>(i) / static_cast<float>(quant_bint_nums);
// merge i bins to target bins
for (int j = 0; j < quant_bint_nums; ++j) {
const float start = j * bin_interval;
const float end = start + bin_interval;
const int left_upper = static_cast<int>(std::ceil(start));
if (left_upper > start) {
const double left_scale = left_upper - start;
quantized_histogram[j] += left_scale * this->histogram[left_upper - 1];
// expand target bins to i bins in order to calculate KL with reference_histogram
for (int j = 0; j < quant_bint_nums; ++j) {
const float start = j * bin_interval;
const float end = start + bin_interval;
float count = 0;
const int left_upper = static_cast<int>(std::ceil(start));
float left_scale = 0.0f;
if (left_upper > start) {
left_scale = left_upper - start;
if (this->histogram[left_upper - 1] != 0) {
count += left_scale;
}
const int right_lower = static_cast<int>(std::floor(end));
if (right_lower < end) {
const double right_scale = end - right_lower;
quantized_histogram[j] += right_scale * this->histogram[right_lower];
}
std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower,
[&quantized_histogram, j](float item) { quantized_histogram[j] += item; });
}
// expand target bins to i bins in order to calculate KL with reference_histogram
for (int j = 0; j < quant_bint_nums; ++j) {
const float start = j * bin_interval;
const float end = start + bin_interval;
float count = 0;
const int left_upper = static_cast<int>(std::ceil(start));
float left_scale = 0.0f;
if (left_upper > start) {
left_scale = left_upper - start;
if (this->histogram[left_upper - 1] != 0) {
count += left_scale;
}
}
const int right_lower = static_cast<int>(std::floor(end));
double right_scale = 0.0f;
if (right_lower < end) {
right_scale = end - right_lower;
if (this->histogram[right_lower] != 0) {
count += right_scale;
}
}
std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower,
[&count](float item) {
if (item != 0) {
count += 1;
}
});
if (count == 0) {
continue;
}
const float average_num = quantized_histogram[j] / count;
if (left_upper > start && this->histogram[left_upper - 1] != 0) {
expanded_histogram[left_upper - 1] += average_num * left_scale;
const int right_lower = static_cast<int>(std::floor(end));
double right_scale = 0.0f;
if (right_lower < end) {
right_scale = end - right_lower;
if (this->histogram[right_lower] != 0) {
count += right_scale;
}
if (right_lower < end && this->histogram[right_lower] != 0) {
expanded_histogram[right_lower] += average_num * right_scale;
}
std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, [&count](float item) {
if (item != 0) {
count += 1;
}
for (int k = left_upper; k < right_lower; ++k) {
if (this->histogram[k] != 0) {
expanded_histogram[k] += average_num;
}
});
if (count == 0) {
continue;
}
const float average_num = quantized_histogram[j] / count;
if (left_upper > start && this->histogram[left_upper - 1] != 0) {
expanded_histogram[left_upper - 1] += average_num * left_scale;
}
if (right_lower < end && this->histogram[right_lower] != 0) {
expanded_histogram[right_lower] += average_num * right_scale;
}
for (int k = left_upper; k < right_lower; ++k) {
if (this->histogram[k] != 0) {
expanded_histogram[k] += average_num;
}
}
auto KLDivergence = [](std::vector<float> p, std::vector<float> q) {
auto sum = 0.0f;
std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; });
std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; });
sum = 0.0f;
std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; });
std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; });
float result = 0.0f;
const int size = p.size();
for (int i = 0; i < size; ++i) {
if (p[i] != 0) {
if (q[i] == 0) {
result += 1.0f;
} else {
result += (p[i] * std::log((p[i]) / (q[i])));
}
}
auto KLDivergence = [](std::vector<float> p, std::vector<float> q) {
auto sum = 0.0f;
std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; });
std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; });
sum = 0.0f;
std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; });
std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; });
float result = 0.0f;
const int size = p.size();
for (int i = 0; i < size; ++i) {
if (p[i] != 0) {
if (q[i] == 0) {
result += 1.0f;
} else {
result += (p[i] * std::log((p[i]) / (q[i])));
}
}
return result;
};
const float kl = KLDivergence(reference_histogram, expanded_histogram);
if (kl < min_kl) {
min_kl = kl;
threshold = i;
}
return result;
};
const float kl = KLDivergence(reference_histogram, expanded_histogram);
if (kl < min_kl) {
min_kl = kl;
threshold = i;
}
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T
<< " max: " << std::max(fabs(this->max), fabs(this->min));
return RET_OK;
}
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T
<< " max: " << std::max(fabs(this->max), fabs(this->min));
return RET_OK;
}
std::pair<CNodePtr, float> GetScale() {
float max_value = this->best_T;
float min_value = -max_value;
std::pair<CNodePtr, float> DivergInfo::GetScale() {
float max_value = this->best_T;
float min_value = -max_value;
MS_ASSERT(quant_max - quant_min != 0);
float scale = (max_value - min_value) / (quant_max - quant_min);
MS_ASSERT(scale != 0);
return std::make_pair(this->cnode, scale);
}
MS_ASSERT(quant_max - quant_min != 0);
float scale = (max_value - min_value) / (quant_max - quant_min);
MS_ASSERT(scale != 0);
return std::make_pair(this->cnode, scale);
}
std::pair<CNodePtr, int32_t> GetZeropoint() {
int zero_point = 0;
if (quant_min == 0 && quant_max == 255) {
zero_point = 128;
} else if (quant_min == -127 && quant_max == 127) {
zero_point = 0;
} else {
MS_LOG(WARNING) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
}
return std::make_pair(this->cnode, zero_point);
std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
int zero_point = 0;
if (quant_min == 0 && quant_max == 255) {
zero_point = 128;
} else if (quant_min == -127 && quant_max == 127) {
zero_point = 0;
} else {
MS_LOG(WARNING) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
}
};
return std::make_pair(this->cnode, zero_point);
}
std::unordered_map<CNodePtr, float> Calibrator::GetScale(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
std::unordered_map<CNodePtr, float> result;
......@@ -359,7 +332,7 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
void Calibrator::AddImage(const string file) {
auto exist = [](const string file) {
struct stat buf{};
struct stat buf {};
return stat(file.c_str(), &buf) == 0;
};
if (exist(file)) {
......
......@@ -23,6 +23,7 @@
#include <vector>
#include <cfloat>
#include <map>
#include <utility>
#include "src/lite_session.h"
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/converter.h"
......@@ -90,13 +91,51 @@ class PostTrainingQuantizer : public Quantizer {
STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel,
bool depthwise);
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel, bool depthwise);
STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c);
};
struct DivergInfo;
struct DivergInfo {
std::vector<float> histogram;
CNodePtr cnode;
int bin_num;
float interval = 0;
float max;
float min;
float best_T = 0.0f;
size_t bit_num;
int quant_max = 255;
int quant_min = 0;
std::string method_x = kMethodKL;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) {
this->method_x = method_x;
this->cnode = cnode;
this->bin_num = bins;
this->bit_num = bits;
histogram.resize(bin_num);
max = -FLT_MAX;
min = FLT_MAX;
this->quant_max = quant_max;
this->quant_min = quant_min;
std::fill(histogram.begin(), histogram.end(), 1.0e-7);
}
STATUS RecordMaxValue(const std::vector<float> &datas);
void UpdateInterval();
STATUS UpdateHistogram(const std::vector<float> &data);
void DumpHistogram();
STATUS ComputeThreshold();
std::pair<CNodePtr, float> GetScale();
std::pair<CNodePtr, int32_t> GetZeropoint();
};
class Calibrator {
public:
......@@ -123,7 +162,7 @@ class Calibrator {
STATUS UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
STATUS UpdateDataFrequency(const std::string& op_name, const std::vector<float>& data,
STATUS UpdateDataFrequency(const std::string &op_name, const std::vector<float> &data,
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
void Dump();
......
......@@ -13,11 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include <memory>
#include <set>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "src/kernel_registry.h"
......@@ -30,7 +30,7 @@ using mindspore::lite::PrimitiveC;
using mindspore::lite::tensor::Tensor;
namespace mindspore::opt {
namespace {
const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
MS_ASSERT(CNode != nullptr);
auto tmp_meta_graph = std::make_unique<schema::MetaGraphT>();
auto tmp_fb_node = std::make_unique<schema::CNodeT>();
......@@ -48,11 +48,11 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
}
auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
// when tensorT as graph input
if (lite_tensor_size == 0) {
if (lite_tensor_size <= 0) {
delete lite_tensor;
return input_tensors;
}
auto tensor_data = new (std::nothrow) char[lite_tensor_size / sizeof(char)];
auto tensor_data = reinterpret_cast<uint8_t *>(malloc(lite_tensor_size / sizeof(char)));
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "tensor_data is nullptr";
delete lite_tensor;
......@@ -61,16 +61,16 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
auto ret = memcpy_s(tensor_data, lite_tensor_size, tensorT->data.data(), lite_tensor_size);
if (ret != EOK) {
delete lite_tensor;
delete tensor_data;
delete[](tensor_data);
MS_LOG(EXCEPTION) << "memcpy error: " << ret;
return input_tensors;
}
lite_tensor->SetData(tensor_data);
input_tensors.emplace_back(lite_tensor);
}
return input_tensors;
}
const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
auto parameter = func_graph->add_parameter();
std::vector<int> shape(tensor->shape());
auto type_id = static_cast<TypeId>(tensor->data_type());
......@@ -102,17 +102,12 @@ const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *ten
parameter->set_default_param(param_value);
return parameter;
}
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs,
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs, OpParameter *parameter,
mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(nullptr != lite_primitive);
auto data_type = inputs.front()->data_type();
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
lite::Context context;
auto parameter = kernel::PopulateParameter(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << (schema::PrimitiveType)primitive->Type();
return nullptr;
}
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
if (creator != nullptr) {
auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive);
......@@ -121,16 +116,19 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
return nullptr;
}
} // namespace
void FreeInputTensor(std::vector<Tensor *> *input_tensor) {
MS_ASSERT(input_tensor != nullptr);
for (size_t i = 0; i < input_tensor->size(); i++) {
if ((*input_tensor)[i] == nullptr) {
continue;
void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
if (input_tensor != nullptr) {
for (size_t i = 0; i < input_tensor->size(); i++) {
delete (*input_tensor)[i];
(*input_tensor)[i] = nullptr;
}
}
if (output_tensor != nullptr) {
for (size_t i = 0; i < output_tensor->size(); i++) {
delete (*output_tensor)[i];
(*output_tensor)[i] = nullptr;
}
delete (*input_tensor)[i];
(*input_tensor)[i] = nullptr;
}
return;
}
const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
......@@ -148,7 +146,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
auto input_cnode = input_node->cast<CNodePtr>();
auto input_tensors = GetCNodeInputTensors(input_cnode);
if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
FreeInputTensor(&input_tensors);
FreeTensors(&input_tensors, nullptr);
continue;
}
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
......@@ -157,39 +155,47 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "lite_primitive is nullptr";
FreeTensors(&input_tensors, &output_tensors);
return nullptr;
}
// here, input_tensor's format need to be transposed nhwc according to fmkType,
// but for the time being, we only transpose the tensor with 0/1/2/3D.
// Others should be added in future.
for (size_t j = 0; j < input_tensors.size(); ++j) {
input_tensors[j]->SetFormat(schema::Format_NHWC);
if (input_tensors[j]->shape().size() == 4) {
MS_LOG(WARNING) << "init input_tensor format to nhwc";
}
input_tensors[j]->SetFormat(schema::Format_NHWC);
if (input_tensors[j]->shape().size() == 4) {
MS_LOG(WARNING) << "init input_tensor format to nhwc";
}
}
lite_primitive->InferShape(input_tensors, output_tensors);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive.get());
auto parameter = kernel::PopulateParameter(lite_primitive.get());
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
return nullptr;
}
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
if (lite_kernel == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
FreeInputTensor(&input_tensors);
FreeTensors(&input_tensors, &output_tensors);
return nullptr;
}
auto ret = lite_kernel->Run();
if (0 != ret) {
FreeInputTensor(&input_tensors);
FreeTensors(&input_tensors, &output_tensors);
MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
return nullptr;
}
auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
if (new_parameter == nullptr) {
FreeInputTensor(&input_tensors);
FreeTensors(&input_tensors, &output_tensors);
MS_LOG(ERROR) << "CreateNewParamter failed, name: " << lite_kernel->name();
return nullptr;
}
new_parameter->set_name(input_node->fullname_with_scope());
any_node->set_input(i, new_parameter);
FreeInputTensor(&input_tensors);
FreeTensors(&input_tensors, &output_tensors);
delete (lite_kernel);
}
}
return any_node;
......
......@@ -17,6 +17,10 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
#include "schema/inner/model_generated.h"
#include "src/ir/tensor.h"
#include "src/lite_kernel.h"
#include "nnacl/op_base.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
......@@ -30,4 +34,3 @@ class ConstFoldPass : public PatternProcessPass {
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册