提交 04371f6d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4364 add aware quant converter

Merge pull request !4364 from cjh9368/aware_quant
......@@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
// add quant param
node->quantType = primitiveT_value->GetQuantType();
if (node->quantType == schema::QuantType_PostTraining) {
if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) {
MS_LOG(INFO) << "node: " << node->name << " add QuantParam";
// activation
auto input_quant_params = primitiveT_value->GetInputQuantParams();
......
......@@ -60,6 +60,17 @@ void AnfImporterFromMetaGraphT::ConverterConstTensor() {
param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(size);
}
if (tensor->quantParams.size() > 0) {
std::unique_ptr<AnfQuantParam> quantParam = std::make_unique<AnfQuantParam>();
quantParam->scale = tensor->quantParams[0]->scale;
quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint;
quantParam->min = tensor->quantParams[0]->min;
quantParam->max = tensor->quantParams[0]->max;
quantParam->narrowRange = tensor->quantParams[0]->narrowRange;
quantParam->numBits = tensor->quantParams[0]->numBits;
quantParam->inited = tensor->quantParams[0]->inited;
param_value->set_quant_param(quantParam);
}
parameter->set_default_param(param_value);
AddNode(i, parameter);
}
......@@ -77,6 +88,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
flag = true;
}
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
// add quant parameter
if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) {
primTValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) {
primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
}
for (int index : cNode->outputIndex) {
primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
}
}
cNode->primitive = nullptr;
auto value_node = NewValueNode(primTValue);
......
......@@ -28,7 +28,7 @@
namespace mindspore {
namespace lite {
OpDefCopyer GetSimpleOpCopyer() {
return [](std::unique_ptr<CNodeT> &inCNode) -> std::unique_ptr<CNodeT> {
return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
std::unique_ptr<CNodeT> newCNode(new CNodeT);
newCNode->name = inCNode->name;
......@@ -421,9 +421,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
}
preTensor->refCount = 0;
preTensor->data.clear();
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) {
MS_LOG(ERROR) << "copy toAddNodeIn failed";
*errorCode = RET_NULL_PTR;
......@@ -456,9 +460,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
// MS_LOG(ERROR)("Copy TensorT failed");
return graphT->nodes.end();
}
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR;
......@@ -505,9 +513,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
*errorCode = RET_NULL_PTR;
return graphT->nodes.end();
}
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR;
......@@ -540,9 +552,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
*errorCode = RET_NULL_PTR;
return graphT->nodes.end();
}
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR;
......
......@@ -36,7 +36,7 @@ enum InsertPlace { kBefore, kAfter };
using NodeIter = std::vector<std::unique_ptr<schema::CNodeT>>::iterator;
using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT>(std::unique_ptr<schema::CNodeT> &)>;
using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT> (schema::CNodeT *)>;
OpDefCopyer GetSimpleOpCopyer();
......
......@@ -19,8 +19,29 @@
#include "tools/common/tensor_util.h"
#include "tools/common/graph_util.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
// Returns a deep copy of the tensor's first (per-tensor) quant parameter,
// or nullptr when the tensor carries no quantization info.
// Only quantParams[0] is consulted; per-channel parameters are not copied.
std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor) {
  MS_ASSERT(tensor != nullptr);
  auto &quantParams = tensor->quantParams;
  if (quantParams.empty()) {
    return nullptr;
  }
  // CopyQuantParamT already yields an rvalue; wrapping it in std::move would
  // only inhibit copy elision, so return it directly.
  return CopyQuantParamT(quantParams.front());
}
// Deep-copies a single QuantParamT (inited, scale, zeroPoint, min, max,
// narrowRange, numBits). The source must be non-null.
std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam) {
  MS_ASSERT(srcQuantParam != nullptr);
  auto dstQuantParam = std::make_unique<schema::QuantParamT>();
  dstQuantParam->inited = srcQuantParam->inited;
  dstQuantParam->scale = srcQuantParam->scale;
  dstQuantParam->zeroPoint = srcQuantParam->zeroPoint;
  dstQuantParam->min = srcQuantParam->min;
  dstQuantParam->max = srcQuantParam->max;
  dstQuantParam->narrowRange = srcQuantParam->narrowRange;
  dstQuantParam->numBits = srcQuantParam->numBits;
  // Plain return enables NRVO / implicit move; `return std::move(local)`
  // would pessimize (clang-tidy: performance-no-automatic-move).
  return dstQuantParam;
}
std::unique_ptr<QuantParamT> CopyQuantParamArrayT(const std::unique_ptr<QuantParamT> &srcQuantParamArray) {
MS_ASSERT(srcQuantParamArray != nullptr);
auto dstQuantParamArrayT = std::unique_ptr<QuantParamT>(new (std::nothrow) QuantParamT());
......@@ -164,6 +185,9 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &oldTenso
newTensor->refCount = oldTensor->refCount;
newTensor->nodeType = oldTensor->nodeType;
newTensor->data = oldTensor->data;
if (!oldTensor->quantParams.empty()) {
newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor)));
}
return std::move(newTensor);
}
......@@ -186,6 +210,4 @@ size_t GetShapeSize(const std::vector<int32_t> &shape) {
}
return shapeSize;
}
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
......@@ -38,6 +38,9 @@ using schema::FusedBatchNormT;
using schema::Format_NCHW;
using schema::Format_NHWC;
using STATUS = int;
std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor);
size_t GetElementSize(const TensorT &tensor);
size_t GetElementSize(const TypeId &dataType);
......@@ -50,6 +53,8 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &);
size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx);
std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam);
std::unique_ptr<schema::QuantParamT> \
CopyQuantParamArrayT(const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray);
......
......@@ -101,6 +101,7 @@ target_link_libraries(converter_lite PRIVATE
node_mid
graph_pass_mid
fusion_mid
quantizer_mid
protobuf
quantizer_mid
pthread
......
......@@ -77,7 +77,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile;
const std::string weightFile = flag->weightFile;
auto meta_graph = modelParser->Parse(modelFile, weightFile);
auto meta_graph = modelParser->Parse(modelFile, weightFile, flag->quantType);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;
......@@ -118,6 +118,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
// transform
transform->SetGraphDef(meta_graph);
transform->CreateQuantizer(flag);
auto status = transform->Transform(*flag);
if (status != 0) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
......@@ -125,6 +126,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
}
return meta_graph;
}
void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) {
auto type = flags->quantType;
switch (type) {
......@@ -132,17 +134,18 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
// mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean));
break;
}
case mindspore::schema::QuantType_WeightQuant: {
MS_LOG(INFO) << "create WeightQuantizer!";
mQuantizer.reset(
new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, flags->bitNum));
break;
}
case mindspore::schema::QuantType_PostTraining: {
MS_LOG(INFO) << "create PostTrainningQuantizer!";
mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
break;
}
// case mindspore::schema::QuantType_WeightQuant: {
// MS_LOG(INFO) << "create WeightQuantizer!";
// mQuantizer.reset(
// new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold,
// flags->bitNum));
// break;
// }
// case mindspore::schema::QuantType_PostTraining: {
// MS_LOG(INFO) << "create PostTrainningQuantizer!";
// mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
// break;
// }
case mindspore::schema::QuantType_QUANT_NONE:
MS_LOG(INFO) << "Not do quantization for model!";
break;
......
......@@ -14,8 +14,12 @@
* limitations under the License.
*/
#include <string>
#include "tools/converter/converter_flags.h"
#include <regex>
#include <string>
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace lite {
......@@ -70,9 +74,11 @@ int Flags::Init(int argc, const char **argv) {
return 1;
}
if (this->inputInferenceTypeIn == "FLOAT") {
this->inputInferenceType = 0;
this->inputInferenceType = TypeId::kNumberTypeFloat;
} else if (this->inputInferenceTypeIn == "UINT8") {
this->inputInferenceType = 1;
this->inputInferenceType = TypeId::kNumberTypeUInt8;
} else if (this->inputInferenceTypeIn == "INT8") {
this->inputInferenceType = TypeId::kNumberTypeInt8;
} else {
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
return 1;
......
......@@ -19,6 +19,7 @@
#include <string>
#include "tools/common/flag_parser.h"
#include "ir/dtype/type_id.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
......@@ -66,7 +67,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
// used for parse aware trainning
std::string inputInferenceTypeIn;
// mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT;
int inputInferenceType = 0;
TypeId inputInferenceType = TypeId::kNumberTypeFloat;
std::string stdDev;
std::string mean;
// used for post-trainning-weight
......
......@@ -16,11 +16,13 @@
#include "tools/converter/graphdef_transform.h"
#include <iostream>
#include <memory>
#include <string>
#include "schema/model_generated.h"
#include "utils/log_adapter.h"
#include "src/common/op_utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h"
......@@ -28,7 +30,7 @@
#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h"
//
// #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h"
......@@ -52,18 +54,45 @@
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/quantizer/aware_quantizer.h"
#include "tools/converter/converter.h"
using std::string;
namespace mindspore {
namespace lite {
namespace mindspore::lite {
GraphDefTransform::GraphDefTransform() = default;
GraphDefTransform::~GraphDefTransform() = default;
void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }
// Instantiates the flatbuffer-level quantizer (fbQuantizer) for the requested
// quant type. Only AwareTrainning is wired up here; every other type falls
// through to `default` and leaves fbQuantizer null (callers check for that
// before calling DoQuantize()). The commented-out cases are kept as a record
// of quantizer kinds that are created elsewhere or not yet supported.
void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
  auto type = flags->quantType;
  switch (type) {
    case QuantType::QuantType_AwareTrainning: {
      MS_LOG(INFO) << "create AwareTrainningQuantizer!";
      // AwareQuantizer works on the schema::MetaGraphT (graphDefT) using the
      // user-supplied input inference type, stdDev and mean.
      fbQuantizer =
        std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean);
      break;
    }
    // case QuantType::QuantType_WeightQuant: {
    //   MS_LOGI("create WeightQuantizer!");
    //   mQuantizer.reset(new WeightQuantizer(graphDefT, flags->quantSize));
    //   break;
    // }
    // case QuantType_PostTraining: {
    //   MS_LOGI("create PostTrainningQuantizer!");
    //   mQuantizer.reset(new PostTrainingQuantizer(graphDefT, flags->configFile));
    //   break;
    // }
    // case QuantType::QuantType_QUANT_NONE:
    //   MS_LOGD("Not do quantization for model!");
    //   break;
    default:
      // MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str());
      break;
  }
}
int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
// // constant folding
......@@ -133,6 +162,53 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
}
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
// generate and infer quant parameters
{
if (mQuantizer != nullptr) {
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) {
status = mQuantizer->GenerateQuantParam();
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateQuantParam failed";
return status;
}
status = mQuantizer->DetermineNodeQuantType();
if (status != RET_OK) {
MS_LOG(ERROR) << "DetermineNodeQuant failed";
}
}
}
}
// format transform
if (ctx.formatTrans) {
Optimizer formatTransOptimizer;
......@@ -156,13 +232,30 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
status = unusedOpRemoveOptimizer.Run(graphDefT);
// do quantization
if (fbQuantizer != nullptr) {
status = fbQuantizer->DoQuantize();
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}
// insert quantNode and deQuantNode
if (ctx.quantType == QuantType_AwareTrainning) {
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_ERROR;
}
dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
}
......@@ -178,6 +271,4 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite
......@@ -17,8 +17,9 @@
#ifndef MS_GRAPHDEF_TRANSFORM_H
#define MS_GRAPHDEF_TRANSFORM_H
#include <memory>
#include "tools/converter/optimizer.h"
// #include "quantizer/quantizer.h"
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
......@@ -42,7 +43,8 @@ class GraphDefTransform {
schema::MetaGraphT *graphDefT = nullptr;
Optimizer *optimizer = nullptr;
// std::unique_ptr<Quantizer> mQuantizer;
std::unique_ptr<quant::Quantizer> mQuantizer;
std::unique_ptr<quant::FbQuantizer> fbQuantizer;
};
} // namespace lite
} // namespace mindspore
......
......@@ -53,7 +53,7 @@ class MatMulBiasAddFusionPass : public FusionPass {
bool transB = false;
size_t id = 0;
OpDefCopyer TransposeOpCopyer = [](const std::unique_ptr<CNodeT> &inOpDef) -> std::unique_ptr<CNodeT> {
OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
std::unique_ptr<CNodeT> newOpDef(new (std::nothrow) CNodeT);
if (newOpDef == nullptr) {
MS_LOG(ERROR) << "new OpDefT failed";
......
add_library(graph_pass_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include <string>
#include "tools/common/converter_op_utils.h"
#include "tools/common/node_util.h"
#include "src/common/common.h"
#include "src/common/utils.h"
namespace mindspore {
namespace lite {
#define kMinInputNum 1
#define kOutputNum 1
// Entry point of the pass: runs the three dtype-transform stages in order
// (model inputs, model outputs, per-node in/out) and stops at the first
// failing stage, returning its status.
STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  // Small stage runner: invokes one member stage on `graph` and logs on
  // failure using the stage's name, keeping the three call sites uniform.
  auto runStage = [this, graph](STATUS (DTypeTransPass::*stage)(schema::MetaGraphT *),
                                const char *stageName) -> STATUS {
    auto ret = (this->*stage)(graph);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << stageName << " error: " << ret;
    }
    return ret;
  };
  auto ret = runStage(&DTypeTransPass::DoModelInputDTypeTrans, "DoModelInputDTypeTrans");
  if (ret != RET_OK) {
    return ret;
  }
  ret = runStage(&DTypeTransPass::DoModelOutputDTypeTrans, "DoModelOutputDTypeTrans");
  if (ret != RET_OK) {
    return ret;
  }
  // Final stage's status doubles as the pass result (RET_OK on success).
  return runStage(&DTypeTransPass::DoNodeInoutDTypeTrans, "DoNodeInoutDTypeTrans");
}
// Retags all graph-input tensors as UInt8 and, unless the user already feeds
// Int8, inserts a dtype-cast node (FP32->Int8 for FLOAT input, UInt8->Int8
// for UINT8 input) between each 4-D graph input and every node consuming it.
STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  // modify inputTensor first
  auto &graphInIdxes = graph->inputIndex;
  for (auto graphInIdx : graphInIdxes) {
    MS_ASSERT(graph->allTensors.size() > graphInIdx);
    auto &graphInTensor = graph->allTensors.at(graphInIdx);
    // NOTE(review): the input tensor dtype is forced to UInt8 unconditionally,
    // even when inputDataDType is FLOAT or INT8 -- confirm this is intended.
    graphInTensor->dataType = TypeId::kNumberTypeUInt8;
  }
  // Int8 input already matches the quantized graph; no cast nodes needed.
  if (this->inputDataDType == TypeId::kNumberTypeInt8) {
    return RET_OK;
  }
  if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) {
    MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
    return RET_ERROR;
  }
  // insert fp2int8 node
  for (auto graphInIdx : graphInIdxes) {
    MS_ASSERT(graphInIdx < graph->allTensors.size());
    auto &tensor = graph->allTensors.at(graphInIdx);
    // Only 4-D (NHWC-shaped) inputs get a cast node; other ranks are skipped.
    if (tensor->dims.size() != kNHWCDimNumber) {
      continue;
    }
    for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
      auto &node = *iter;
      // Copy the name up front: insertions below may invalidate `node`.
      auto nodeName = node->name;
      for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
        if (node->inputIndex.at(inputIndexIdx) == graphInIdx) {
          STATUS status = RET_OK;
          // insert dtype cast node between input tensor and input node;
          // InsertDTypeTransNode returns the updated iterator so the scan
          // can continue past the newly inserted node.
          if (inputDataDType == TypeId::kNumberTypeFloat) {
            iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kFP32ToInt8, &status);
          } else {
            iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kUInt8ToInt8, &status);
          }
          if (status != RET_OK) {
            MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed";
            return status;
          }
        }
      }
    }
  }
  return RET_OK;
}
// Inserts a dtype-cast node after each node that produces a graph output so
// the model emits data in the user-requested type: Int8->FP32 for FLOAT
// output; no-op when the user consumes Int8 directly.
STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  if (inputDataDType == TypeId::kNumberTypeInt8) {
    return RET_OK;
  }
  // NOTE(review): this assert excludes UINT8, yet the else-branch below
  // inserts kInt8ToUInt8 -- in release builds (assert compiled out) a UINT8
  // inputDataDType would take that path. Confirm which behavior is intended.
  MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat);
  auto &graphOutIdxes = graph->outputIndex;
  for (auto graphOutIdx : graphOutIdxes) {
    for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
      auto &node = *iter;
      // Copy the name up front: the insertion below may invalidate `node`.
      auto nodeName = node->name;
      MS_ASSERT(node != nullptr);
      for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) {
        if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) {
          // insert transNode
          STATUS status = RET_OK;
          if (inputDataDType == TypeId::kNumberTypeFloat) {
            iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
          } else {
            iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
          }
          if (status != RET_OK) {
            MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
            return status;
          }
          // Each graph output has one producer slot; stop scanning this node.
          break;
        }
      }
    }
  }
  return RET_OK;
}
// Wraps every node that cannot run in int8 with dequant/quant casts:
// Int8->FP32 before each non-constant input and FP32->Int8 after each output,
// then clears the node's quantType. Nodes that stay int8-capable under aware
// training, and existing QuantDTypeCast nodes, are left untouched.
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  // insert transNode before and after existNode
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    // Ops in the uint8 op list that were aware-training quantized can run
    // quantized as-is; skip them.
    if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) {
      continue;
    }
    auto &node = *iter;
    // Never wrap an existing cast node with more casts.
    if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
      continue;
    }
    bool needInsertPost = true;
    // Shape outputs integer metadata, not activations; no post-cast needed.
    if (GetCNodeTType(**iter) == PrimitiveType_Shape) {
      needInsertPost = false;
    }
    // Copy the name up front: insertions below may invalidate `node`.
    auto nodeName = node->name;
    if (node->inputIndex.size() < kMinInputNum) {
      MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
      return RET_ERROR;
    }
    // NOTE(review): `status` is only written by InsertDTypeTransNode; it is
    // read right after each call, so it is never used uninitialized here.
    STATUS status;
    // insert pre
    for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
      MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
      auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
      auto &graphInIdxes = graph->inputIndex;
      // Constant (weight/bias) inputs are dequantized elsewhere; only cast
      // activations and actual graph inputs.
      if (preTensor->nodeType == NodeType_ValueNode && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
        continue;
      }
      iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
        return RET_ERROR;
      }
    }
    if (needInsertPost) {
      for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
        iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
          return RET_ERROR;
        }
      }
    }
    // The node now runs in float between the inserted casts.
    (*iter)->quantType = QuantType_QUANT_NONE;
  }
  return RET_OK;
}
// Builds a QuantDTypeCast node for the requested conversion (Int8<->FP32,
// UInt8<->Int8) and inserts it before or after the node at existNodeIter via
// InsertNode. Returns the updated iterator; on failure sets *errorCode and
// returns graph->nodes.end(). The node name encodes the conversion, the
// anchor node's name, a _pre/_post suffix and a running id for uniqueness.
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
                                              size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) {
  MS_ASSERT((*existNodeIter) != nullptr);
  auto existNodeName = (*existNodeIter)->name;
  std::string tileName;
  if (place == kBefore) {
    tileName = existNodeName + "_pre";
  } else {
    tileName = existNodeName + "_post";
  }
  auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
  if (transNode == nullptr) {
    MS_LOG(ERROR) << "new TransNode failed";
    *errorCode = RET_ERROR;
    return graph->nodes.end();
  }
  auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT;
  if (quantDTypeCastParam == nullptr) {
    MS_LOG(ERROR) << "new quantDTypeCastParam failed";
    *errorCode = RET_ERROR;
    return graph->nodes.end();
  }
  transNode->primitive = std::make_unique<schema::PrimitiveT>();
  transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
  transNode->quantType = QuantType_AwareTrainning;
  if (nodeType == kInt8ToFP32) {
    quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
    quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32;
    transNode->name = "int8toft32_" + tileName + std::to_string(id++);
  } else if (nodeType == kFP32ToInt8) {
    quantDTypeCastParam->srcT = TypeId::kNumberTypeFloat32;
    quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
    transNode->name = "ft32toint8_" + tileName + std::to_string(id++);
  } else if (nodeType == kUInt8ToInt8) {
    quantDTypeCastParam->srcT = TypeId::kNumberTypeUInt8;
    quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
    transNode->name = "uint8toint8_" + tileName + std::to_string(id++);
  } else if (nodeType == kInt8ToUInt8) {
    quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
    quantDTypeCastParam->dstT = TypeId::kNumberTypeUInt8;
    transNode->name = "int8touint8_" + tileName + std::to_string(id++);
  }
  // Hand the configured param to the primitive union exactly once (the
  // original assigned value.value both before and after the branches, which
  // was redundant). The union takes ownership of the raw pointer.
  transNode->primitive->value.value = quantDTypeCastParam;
  return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, castOpCopyer);
}
void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
} // namespace lite
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H
#define MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H
#include <memory>
#include <utility>
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"
#include "tools/common/tensor_util.h"
namespace mindspore {
namespace lite {
enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 };
// Graph pass that inserts QuantDTypeCast nodes so an aware-training quantized
// graph can exchange data with the user in the configured input type and run
// non-int8-capable ops in float. See the matching .cc for the stage logic.
class DTypeTransPass : public GraphPass {
 public:
  DTypeTransPass() : id(0) {}

  ~DTypeTransPass() override = default;

  STATUS Run(schema::MetaGraphT *graph) override;

  // Sets the dtype the user feeds/reads at the model boundary.
  void SetInputDataDType(TypeId dataType);

 private:
  STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);

  STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph);

  STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);

  NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
                                DTypeTransNodeType nodeType, STATUS *errorCode);

 private:
  // Monotonic counter appended to generated cast-node names for uniqueness.
  size_t id;
  TypeId inputDataDType = TypeId::kNumberTypeFloat;
  // Copier used by InsertNode for the cast nodes this pass creates. It only
  // understands QuantDTypeCast primitives: copying any other node type now
  // fails with nullptr instead of dereferencing a null param (the original
  // used AsQuantDTypeCast() unchecked).
  OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
    std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);
    if (newCNode == nullptr) {
      MS_LOG(ERROR) << "new CNodeT failed";
      return nullptr;
    }
    newCNode->name = inCNode->name;
    newCNode->quantType = inCNode->quantType;
    newCNode->primitive = std::make_unique<schema::PrimitiveT>();
    newCNode->primitive->value.type = inCNode->primitive->value.type;
    auto oldQuantDTypeCastParam = inCNode->primitive->value.AsQuantDTypeCast();
    if (oldQuantDTypeCastParam == nullptr) {
      MS_LOG(ERROR) << "primitive is not QuantDTypeCast";
      return nullptr;
    }
    auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT;
    if (quantDTypeCastParam == nullptr) {
      MS_LOG(ERROR) << "new QuantDTypeCast failed";
      return nullptr;
    }
    quantDTypeCastParam->srcT = oldQuantDTypeCastParam->srcT;
    quantDTypeCastParam->dstT = oldQuantDTypeCastParam->dstT;
    // The primitive union takes ownership of the raw pointer.
    newCNode->primitive->value.value = quantDTypeCastParam;
    // Implicit move on return; no std::move needed.
    return newCNode;
  };
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H
......@@ -209,6 +209,9 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return 0;
}
// inference needed filterFormat:
// conv deconv depth dedepth
// uint8 KHWC KHWC KHWC KHWC
int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
MS_ASSERT(graphNode != nullptr);
auto &subGraph = graphNode->subGraph;
......@@ -227,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
auto &weightTensor = subGraph->allTensors[weightIndex];
MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT
STATUS status = RET_OK;
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK
if (opType == schema::PrimitiveType_Conv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
......@@ -236,58 +239,51 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else {
MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
<< weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
}
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
return RET_OK;
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else {
} else if (weightTensor->format != schema::Format_KHWC) {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
}
if (status == 0) {
node->primitive->value.AsConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_HWCK;
weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(WARNING) << "TransFilter %sToHWCK failed, node : "
<< (weightTensor->format == schema::Format_KCHW ? "KCHW" : "KHWC"),
node->name.c_str();
MS_LOG(WARNING) << "TransFilter %sToKHWC failed, node : "
<< (weightTensor->format == schema::Format_KHWC ? "KHWC" : "KCHW") << node->name.c_str();
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_CKHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK);
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<int8_t>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->dataType == kNumberTypeUInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2KHWC);
} else {
MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2HWCK);
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
}
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
MS_LOG(DEBUG) << node->name << " weight trans format: CHWK->KHWC";
} else if (weightTensor->dataType == kNumberTypeUInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2KHWC);
} else {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
}
} else if (weightTensor->format == schema::Format_KCHW) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
}
} else {
} else if (weightTensor->format != schema::Format_KHWC) {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
}
......@@ -295,14 +291,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : "
<< (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"),
node->name.c_str();
MS_LOG(WARNING) << "TransFilter" << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "CKHW")
<< "To KHWC failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be HWCK
node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
weightTensor->format = schema::Format_CKHW;
} else { // weight should be HWCK
node->primitive->value.AsDeConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_KHWC;
}
return 0;
}
......@@ -354,7 +349,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
weightTensor->format = schema::Format_CKHW;
}
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->format == schema::Format_KCHW) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
......@@ -374,8 +369,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
status = RET_OK;
} else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
......
......@@ -40,7 +40,8 @@ class ModelParser {
}
return Fb2Anf(Parse(modelFile, weightFile));
}
virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) = 0;
virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
public:
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {
......
......@@ -31,7 +31,8 @@ CaffeModelParser::~CaffeModelParser() {}
const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};
schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
std::unique_ptr<schema::MetaGraphT> graph(new schema::MetaGraphT());
if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) {
......@@ -91,7 +92,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const
// ConvertCaffeBatchNorm(graph.get());
return graph.release();
// return Fb2Anf(graph.release());
// return Fb2Anf(graph.release());
}
STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op,
......
......@@ -33,7 +33,8 @@ class CaffeModelParser : public ModelParser {
virtual ~CaffeModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override;
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
void ConvertCaffeBatchNorm(MetaGraphT *meta_graphT);
......
......@@ -37,7 +37,8 @@ class OnnxModelParser : public ModelParser {
public:
OnnxModelParser();
virtual ~OnnxModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override;
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
......
......@@ -20,7 +20,6 @@
#include "tools/common/graph_util.h"
#include "tools/common/storage.h"
#include "flatbuffers/flatbuffers.h"
#include "utils/log_adapter.h"
#include "src/common/file_utils.h"
namespace mindspore {
......@@ -60,42 +59,64 @@ STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema:
}
return RET_OK;
}
void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
schema::TensorT *tensor) {
std::unique_ptr<schema::QuantParamT> quant_param(new QuantParamT());
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[0];
}
STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
auto dst_op = tfliteOpMap.at(tflite_op.get());
if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
}
std::vector<uint32_t> quant_params_index;
quant_params_index.insert(quant_params_index.end(), tflite_op->inputs.begin(), tflite_op->inputs.end());
quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end());
for (const auto &index : quant_params_index) {
const auto &tflite_tensor = tflite_subgraph->tensors[index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << index <<" is null";
return RET_ERROR;
}
// change quant param min to 0 to fit ms-lite ops
if (tensor->dataType == TypeId::kNumberTypeInt8) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
}
if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[0];
}
if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[0];
}
quant_param->inited = true;
tensor->quantParams.clear();
tensor->quantParams.emplace_back(std::move(quant_param));
}
STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
schema::CNodeT *op, TensorCache *tensor_cache) {
MS_ASSERT(op->outputIndex.size() == tflite_op->outputs.size());
for (size_t i = 0; i < tflite_op->inputs.size() && i < op->inputIndex.size(); i++) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(i)];
if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) {
continue;
}
std::unique_ptr<schema::QuantParamT> quant_param(new schema::QuantParamT());
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[0];
}
if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
auto &inTensor = tensor_cache->GetCachedTensor().at(op->inputIndex.at(i));
if (inTensor == nullptr) {
MS_LOG(ERROR) << "Parse tflite quant params inTensor is null";
return RET_NULL_PTR;
}
if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[0];
SetMsTensorFromTflite(tflite_tensor, inTensor);
}
for (size_t i = 0; i < tflite_op->outputs.size() && i < op->outputIndex.size(); i++) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->outputs.at(i)];
if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) {
continue;
}
if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[0];
auto &outTensor = tensor_cache->GetCachedTensor().at(op->outputIndex.at(i));
if (outTensor == nullptr) {
MS_LOG(ERROR) << "Parse tflite quant params outTensor is null";
return RET_NULL_PTR;
}
SetMsTensorFromTflite(tflite_tensor, outTensor);
}
dst_op->quantType = schema::QuantType_AwareTrainning;
return RET_OK;
}
......@@ -105,11 +126,15 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
for (const auto &index : tflite_op->outputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << index <<" is null";
MS_LOG(ERROR) << "tensor with id = " << index << " is null";
return RET_ERROR;
}
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT());
tensor->dataType = GetTfliteDataType(tflite_tensor->type);
// change dataType to int8 to fit ms-lite op
if (tensor->dataType == TypeId::kNumberTypeUInt8) {
tensor->dataType = TypeId::kNumberTypeInt8;
}
tensor->dims = tflite_tensor->shape;
tensor->nodeType = schema::NodeType_Parameter;
auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT);
......@@ -120,7 +145,8 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache) {
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache) {
auto op_type = GetTfliteNodeType(tflite_op, tflite_model);
std::vector<int32_t> op_inputs(tflite_op->inputs);
if (op_type == "DeConv2D") {
......@@ -130,12 +156,11 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t
for (const auto &tflite_index : op_inputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << tflite_index <<" is null";
MS_LOG(ERROR) << "tensor with id = " << tflite_index << " is null";
return RET_ERROR;
}
auto tensor_name = tflite_tensor->name;
auto op = tfliteOpMap[tflite_op.get()];
unsigned int index = tensorCache->FindTensor(tensor_name);
unsigned int index = tensor_cache->FindTensor(tensor_name);
if (index != -1) {
op->inputIndex.push_back(index);
}
......@@ -146,19 +171,20 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t
STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
schema::MetaGraphT *subGraph,
mindspore::lite::TensorCache *tensorCache) {
schema::MetaGraphT *subGraph, mindspore::lite::TensorCache *tensorCache,
const QuantType &quantType) {
auto i = 0;
for (const auto &tflite_op : tflite_subgraph->operators) {
auto opType = GetTfliteNodeType(tflite_op, tflite_model);
std::unique_ptr<schema::CNodeT> op(new schema::CNodeT);
op->name = opType + "-" + std::to_string(i++);
op->quantType = quantType;
MS_LOG(INFO) << "parse op: " << op->name.c_str();
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType);
if (node_parser == nullptr) {
MS_LOG(ERROR) << "cannot find node parser, opType: "<< opType.c_str();
MS_LOG(ERROR) << "cannot find node parser, opType: " << opType.c_str();
continue;
// return RET_NULL_PTR;
}
......@@ -172,7 +198,19 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_
status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "Set Op "<< op->name.c_str() << " Output Index Failed!";
MS_LOG(ERROR) << "set op " << opType.c_str() << " output index failed";
return RET_ERROR;
}
status = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "set op " << opType.c_str() << " input index failed";
return RET_ERROR;
}
status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed";
return RET_ERROR;
}
......@@ -189,8 +227,10 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT>
const auto &tflite_tensor = tflite_subgraph->tensors[index];
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT());
tensor->format = schema::Format_NHWC;
tensor->dataType = GetTfliteDataType(tflite_tensor->type);
tensor->nodeType = schema::NodeType_ValueNode;
tensor->dataType = GetTfliteDataType(tflite_tensor->type) != TypeId::kNumberTypeUInt8
? GetTfliteDataType(tflite_tensor->type)
: TypeId::kNumberTypeInt8;
tensor->nodeType = schema::NodeType_Parameter;
tensor->dims = tflite_tensor->shape;
tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT);
}
......@@ -212,7 +252,8 @@ void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &
}
}
MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".tflite") != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite";
return nullptr;
......@@ -224,7 +265,6 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
MS_LOG(ERROR) << "read tflite model failed";
return nullptr;
}
if (tflite_model->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed";
return nullptr;
......@@ -238,30 +278,15 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
// set dst subGraph op attr and tensor_cache.
std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT);
subGraph->name = "MS_model converted by TF-Lite";
auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache);
auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache, quantType);
if (status != RET_OK) {
MS_LOG(ERROR) << "ParseOp failed.";
return nullptr;
}
for (const auto &tflite_op : tflite_subgraph->operators) {
auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache);
if (status_tmp != RET_OK) {
MS_LOG(ERROR) << "Set Op " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Input Index Failed!";
}
}
for (const auto &tflite_op : tflite_subgraph->operators) {
auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op);
if (statusTmp != RET_OK) {
MS_LOG(ERROR) << "ParseTfliteQuantParams " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Failed!";
}
}
SetGraphTensorIndex(tensorCache, subGraph.get());
SetAllTensors(tensorCache, subGraph.get());
return subGraph.release();
}
} // namespace lite
} // namespace mindspore
......@@ -40,22 +40,25 @@ class TfliteModelParser : public ModelParser {
virtual ~TfliteModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile);
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
std::unique_ptr<tflite::ModelT> ReadTfliteModelFromFlat(const char *buf);
void SetMsTensorFromTflite(const std::unique_ptr<tflite::TensorT> &tflite_tensor, schema::TensorT *tensor);
void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef);
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef);
STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph,
TensorCache *tensor_cache);
TensorCache *tensor_cache, const QuantType &quantType);
STATUS ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache);
std::string GetTfliteNodeType(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model);
......@@ -63,13 +66,13 @@ class TfliteModelParser : public ModelParser {
STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph);
STATUS SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensorCache);
STATUS SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache);
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache);
std::map<std::string, schema::CNodeT *> opMap;
std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap;
......
......@@ -4,7 +4,9 @@ include_directories(${3RD_DIR}/flatbuffers/include)
include_directories(${3RD_DIR}/opencv/build/include/opencv4)
add_library(quantizer_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MS_AWARE_QUANTIZER_H
#define MS_AWARE_QUANTIZER_H
#include <array>
#include <string>
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
#include "include/errorcode.h"
namespace mindspore::lite::quant {
struct InputArray;
// Quantizer for aware-training (fake-quant) models: strips fake-quant nodes,
// propagates/derives per-tensor quant params across the MetaGraph, decides
// each node's quant type, and finally quantizes weights/bias data in place.
class AwareQuantizer : public FbQuantizer {
 public:
  // inputInferType / stdValues / meanValues describe how the graph input
  // tensor is pre-processed; they seed mInputArray's quant param.
  AwareQuantizer(schema::MetaGraphT *graph, const std::string &inputInferType, const std::string &stdValues,
                 const std::string &meanValues);
  ~AwareQuantizer() { delete (mInputArray); }
  // Owns mInputArray via raw pointer: copying would double-free, so forbid it.
  AwareQuantizer(const AwareQuantizer &) = delete;
  AwareQuantizer &operator=(const AwareQuantizer &) = delete;
  STATUS RemoveFakeQuant() override;
  STATUS GenerateQuantParam() override;
  STATUS DetermineNodeQuantType() override;
  STATUS DoQuantize() override;

 private:
  // RemoveFakeQuant helpers
  STATUS SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node);
  STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph);
  STATUS QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);
  STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node);
  STATUS QuantConvBias(const schema::MetaGraphT *graph, schema::CNodeT *node);
  STATUS QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node);
  float inputScale = 0.0f;
  // Initialized to nullptr so the destructor's delete is safe even if the
  // constructor never allocates it.
  InputArray *mInputArray = nullptr;
  // Ops through which quant params are propagated unchanged.
  static const std::array<schema::PrimitiveType, 7> propagatedOps;
};
}  // namespace mindspore::lite::quant
#endif  // MS_AWARE_QUANTIZER_H
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/calc_quant_param.h"
#include <cfloat>
#include <memory>
#include <algorithm>
#include <utility>
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "schema/inner/ops_generated.h"
#include "src/common/utils.h"
namespace mindspore::lite {
// Derives a quant param for a constant (weight) tensor by scanning its float
// payload for the value range. int32/uint8 tensors need no computation; any
// other non-float type is an error. The range always includes zero, and a
// degenerate all-zero tensor is widened to [0, 1].
STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) {
  MS_ASSERT(quantParam != nullptr);
  // int32 weight no need to quant; uint8 is likewise left untouched
  if (tensor.dataType == TypeId::kNumberTypeInt32 || tensor.dataType == TypeId::kNumberTypeUInt8) {
    return RET_OK;
  }
  if (tensor.dataType != TypeId::kNumberTypeFloat) {
    // MS_LOGW("Const Tensor without quantParam should has float dataType, in fact: %d", tensor.dataType);
    return RET_ERROR;
  }
  const auto *values = reinterpret_cast<const float *>(tensor.data.data());
  const size_t elemCount = GetShapeSize(tensor);
  // Start both bounds at zero so the quantized range always contains 0.
  float rangeMin = 0.0f;
  float rangeMax = 0.0f;
  for (size_t idx = 0; idx < elemCount; idx++) {
    rangeMin = std::min(rangeMin, values[idx]);
    rangeMax = std::max(rangeMax, values[idx]);
  }
  if (rangeMin == 0.0f && rangeMax == 0.0f) {
    rangeMax = 1.0f;  // avoid a zero-width range
  }
  // Check whether every element already sits exactly on a range endpoint;
  // if not, quantizing this const tensor may cost inference accuracy.
  bool exactOnEndpoints = true;
  for (size_t idx = 0; idx < elemCount; idx++) {
    exactOnEndpoints &= (values[idx] == rangeMin || values[idx] == rangeMax);
  }
  if (!exactOnEndpoints) {
    // //MS_LOGD("compute quantParam for const tensor may be a cause of poor inference accuracy");
  }
  return quant::CalQuantizationParams(quantParam, rangeMin, rangeMax);
}
// init inTensor quantParam from preNode if possable
// init outTensor quantParam from postNode if possable
// Base pass over one node: seeds quant params that can be determined locally
// and counts how many inputs/outputs are already resolved (inputParamDone /
// outputParamDone), for use by subclass calcers.
// - inputs: a const input that is not a graph input gets a param computed
//   from its float data (ComputeConstQuantParam).
// - outputs: only counted; nothing is computed here.
// Returns RET_OK, or the error from ComputeConstQuantParam.
int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
MS_ASSERT(node.inputIndex.size() > 0);
MS_ASSERT(node.quantParam.size() == node.inputIndex.size() + node.outputIndex.size());
inputParamDone = 0;
auto inputTensorSize = node.inputIndex.size();
for (size_t i = 0; i < inputTensorSize; i++) {
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
auto &tensor = graph->allTensors.at(node.inputIndex.at(i));
MS_ASSERT(tensor != nullptr);
// NOTE(review): GetTensorQuantParam's result is dereferenced without a null
// check here (unlike the output loop below) — confirm it never returns null.
auto quantParam = GetTensorQuantParam(tensor);
if (quantParam->inited) {  // inited
inputParamDone++;
continue;
}
// NOTE(review): the two asserts below duplicate the ones at the top of this
// loop iteration and could be dropped.
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
MS_ASSERT(tensor != nullptr);
// NOTE(review): comparing refCount against the NodeType_ValueNode enum value
// looks suspicious — elsewhere in this file const tensors are detected with
// refCount == 999; confirm which sentinel marks a const tensor.
if (tensor->refCount == schema::NodeType_ValueNode && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
auto status = ComputeConstQuantParam((*tensor), quantParam.get());
if (status != RET_OK) {
// MS_LOGW("ComputeConstQuantParam failed: %d", status);
return status;
}
// Store the freshly computed param back into the tensor.
tensor->quantParams.front() = std::move(quantParam);
inputParamDone++;
continue;
}
}
outputParamDone = 0;
for (unsigned int i : node.outputIndex) {
MS_ASSERT(graph->allTensors.size() > i);
auto &tensor = graph->allTensors.at(i);
MS_ASSERT(tensor != nullptr);
auto quantParam = GetTensorQuantParam(tensor);
MS_ASSERT(quantParam != nullptr);
if (quantParam->inited) {  // inited
outputParamDone++;
continue;
}
// 999 appears to be the magic refCount marking a const tensor (see CalcAdd);
// a const output tensor should be impossible, hence the hard assert.
if (tensor->refCount == 999) {
MS_ASSERT(false);
}
}
return RET_OK;
}
// Common case: after the base pass, every input and every output quant param
// must be fully determined; anything still missing is an error for this node.
int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
  const auto ret = QuantParamCalcer::Calc(subGraph, node);
  if (ret != RET_OK) {
    // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", ret);
    return ret;
  }
  if (node.inputIndex.size() != inputParamDone) {
    MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name.c_str();
    return RET_ERROR;
  }
  if (node.outputIndex.size() != outputParamDone) {
    MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name.c_str();
    return RET_ERROR;
  }
  return RET_OK;
}
// Linear (quant-transparent) ops share one quant param between inputs and
// outputs:
// - if some inputs are undetermined, copy the first output's inited param
//   onto each uninited input;
// - if some outputs are undetermined, copy the first input's inited param
//   onto each uninited output.
// Returns RET_ERROR when the needed source param is itself uninited.
int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
  auto status = QuantParamCalcer::Calc(graph, node);
  if (status != RET_OK) {
    // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
    return status;
  }
  if (inputParamDone != node.inputIndex.size()) {
    MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(0));
    auto &outTensor = graph->allTensors.at(node.outputIndex.at(0));
    MS_ASSERT(outTensor != nullptr);
    auto outputQuantParam = GetTensorQuantParam(outTensor);
    MS_ASSERT(outputQuantParam != nullptr);
    if (!outputQuantParam->inited) {
      // MS_LOGW("Can not determine inputTensor quantParam from outputTensor for node %s", node.name.c_str());
      return RET_ERROR;
    }
    for (unsigned int i : node.inputIndex) {
      // Fix: 'i' is already the tensor index (range-for over inputIndex), so
      // bound-check it directly instead of node.inputIndex.at(i).
      MS_ASSERT(graph->allTensors.size() > i);
      auto &inTensor = graph->allTensors.at(i);
      MS_ASSERT(inTensor != nullptr);
      auto inQuantParam = GetTensorQuantParam(inTensor);
      if (inQuantParam->inited) {
        continue;
      }
      // Fix: actually copy the output's param; the previous code moved the
      // still-uninited param straight back, leaving the input undetermined.
      *inQuantParam = *outputQuantParam;
      inTensor->quantParams.front() = std::move(inQuantParam);
    }
  }
  if (outputParamDone != node.outputIndex.size()) {
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
    auto &inTensor = graph->allTensors.at(node.inputIndex.at(0));
    MS_ASSERT(inTensor != nullptr);
    auto inQuantParam = GetTensorQuantParam(inTensor);
    if (!inQuantParam->inited) {
      // MS_LOGW("Can not determine outputTensor quantParam from inputTensor for node %s", node.name.c_str());
      return RET_ERROR;
    }
    for (size_t i = 0; i < node.outputIndex.size(); i++) {
      MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(i));
      auto &outTensor = graph->allTensors.at(node.outputIndex.at(i));
      MS_ASSERT(outTensor != nullptr);
      auto outQuantParam = GetTensorQuantParam(outTensor);
      if (outQuantParam->inited) {
        continue;
      }
      // Fix: resolves the "todo copy quant params" — copy the input's param
      // onto the output instead of moving the uninited param back unchanged.
      *outQuantParam = *inQuantParam;
      outTensor->quantParams.front() = std::move(outQuantParam);
    }
  }
  return RET_OK;
}
// Concat: all inputs must already carry inited quant params; the single
// output's param covers the union [min(mins), max(maxs)] of the input ranges,
// with numBits/narrowRange taken from the inputs (asserted identical).
class CalcConcat : public QuantParamCalcer {
 public:
  CalcConcat() = default;

  int Calc(MetaGraphT *graph, const CNodeT &node) override {
    MS_ASSERT(node.outputIndex.size() == 1);
    auto status = QuantParamCalcer::Calc(graph, node);
    if (status != RET_OK) {
      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
      return status;
    }
    if (inputParamDone != node.inputIndex.size()) {
      // MS_LOGW("Can not determine concat inputTensor quantParam, node %s", node.name.c_str());
      return RET_ERROR;
    }
    if (outputParamDone != 1) {
      MS_ASSERT(outputParamDone == 0);
      float minMin = FLT_MAX;
      // Fix: FLT_MIN is the smallest *positive* float, so the running max
      // could never move below it; -FLT_MAX is the true lower bound.
      float maxMax = -FLT_MAX;
      bool narrowRange = false;
      int numBits = -1;
      for (size_t i = 0; i < node.inputIndex.size(); i++) {
        MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
        // Fix: index allTensors with the node's input tensor index, not the
        // bare loop counter i.
        auto &inTensor = graph->allTensors.at(node.inputIndex.at(i));
        MS_ASSERT(inTensor != nullptr);
        auto inQuantParam = GetTensorQuantParam(inTensor);
        MS_ASSERT(inQuantParam != nullptr);
        if (!inQuantParam->inited) {
          return RET_ERROR;
        }
        if (numBits == -1) {
          narrowRange = inQuantParam->narrowRange;
          numBits = inQuantParam->numBits;
        } else {
          // Fix: assert against the current input's param; the old code
          // referenced an undeclared name 'quantParam'.
          MS_ASSERT(narrowRange == inQuantParam->narrowRange);
          MS_ASSERT(numBits == inQuantParam->numBits);
        }
        if (minMin > inQuantParam->min) {
          minMin = inQuantParam->min;
        }
        if (maxMax < inQuantParam->max) {
          maxMax = inQuantParam->max;
        }
      }
      MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
      auto &outTensor = graph->allTensors.at(node.outputIndex.front());
      MS_ASSERT(outTensor != nullptr);
      auto outQuantParam = GetTensorQuantParam(outTensor);
      status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits);
      if (status != RET_OK) {
        // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
        return RET_ERROR;
      }
      // Fix: store the computed param back into the tensor; the old code
      // dropped the local copy, leaving the output undetermined.
      outTensor->quantParams.front() = std::move(outQuantParam);
      outputParamDone++;
    }
    return RET_OK;
  }
};
// Add with a scalar const bias: the output quant param is the non-const
// input's range shifted by the bias value ([min+bias, max+bias]).
// Requires exactly 2 inputs / 1 output, both input params already inited,
// and exactly one input being a const scalar (the "bias").
class CalcAdd : public QuantParamCalcer {
public:
CalcAdd() = default;
int Calc(MetaGraphT *graph, const CNodeT &node) override {
MS_ASSERT(node.inputIndex.size() == 2);
MS_ASSERT(node.outputIndex.size() == 1);
auto status = QuantParamCalcer::Calc(graph, node);
if (status != RET_OK) {
// MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
return status;
}
if (inputParamDone != 2) {
// MS_LOGW("Can not determine add inputTensor quantParam, node %s", node.name.c_str());
return RET_ERROR;
}
if (outputParamDone != 1) {
MS_ASSERT(outputParamDone == 0);
MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
auto &outTensor = graph->allTensors.at(node.outputIndex.front());
MS_ASSERT(outTensor != nullptr);
auto outQuantParam = GetTensorQuantParam(outTensor);
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0));
MS_ASSERT(tensor0 != nullptr);
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1));
auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1));
MS_ASSERT(tensor1 != nullptr);
// Decide which input is the const scalar bias. refCount == 999 is the
// sentinel used in this file for const tensors (see QuantParamCalcer::Calc);
// a scalar is an empty or rank-1 dims vector.
auto biasTensor = &tensor0;
auto paramTensor = &tensor1;
if (tensor0->refCount == 999 && (tensor0->dims.empty() || tensor0->dims.size() == 1)) {
biasTensor = &tensor0;
paramTensor = &tensor1;
} else if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
biasTensor = &tensor1;
paramTensor = &tensor0;
} else {
// MS_LOGW("Can not determine add outputTensor quantParam, node %s", node.name.c_str());
return RET_ERROR;
}
// Non-const input's already-inited param supplies the base range.
auto quantParam = GetTensorQuantParam(*paramTensor);
MS_ASSERT(quantParam != nullptr);
MS_ASSERT(quantParam->inited);
auto min = quantParam->min;
auto max = quantParam->max;
{
if ((*biasTensor)->dataType == TypeId::kNumberTypeFloat) {
MS_ASSERT((*biasTensor)->data.size() == sizeof(float) / sizeof(uint8_t));
void *oriTensorData = (*biasTensor)->data.data();
auto *bias = static_cast<float *>(oriTensorData);
// Shift the whole range by the bias and recompute scale/zeroPoint.
status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
return RET_ERROR;
}
} else if ((*biasTensor)->dataType == TypeId::kNumberTypeUInt8) {
MS_ASSERT((*biasTensor)->data.size() == 1);
void *oriTensorData = (*biasTensor)->data.data();
auto *bias = static_cast<uint8_t *>(oriTensorData);
// NOTE(review): the raw uint8 byte is added to the float range without
// dequantizing via the bias tensor's own scale/zeroPoint — confirm this
// is intended.
status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
return RET_ERROR;
}
} else {
// MS_LOGW("Unsupported tensor dataType: %d", (*biasTensor)->dataType);
return RET_ERROR;
}
}
// NOTE(review): outQuantParam is a local copy and is never stored back into
// outTensor->quantParams (unlike LinearCalcer) — verify the computed output
// param is not silently dropped.
}
return RET_OK;
}
};
class CalcRealDiv : public QuantParamCalcer {
public:
CalcRealDiv() = default;
int Calc(MetaGraphT *graph, const CNodeT &node) override {
MS_ASSERT(node.inputIndex.size() == 2);
MS_ASSERT(node.outputIndex.size() == 1);
auto status = QuantParamCalcer::Calc(graph, node);
if (status != RET_OK) {
// MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
return status;
}
if (inputParamDone != 2) {
// MS_LOGW("Can not determine realdiv inputTensor quantParam, node %s", node.name.c_str());
return RET_ERROR;
}
if (outputParamDone != 1) {
MS_ASSERT(outputParamDone == 0);
MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
auto &outTensor = graph->allTensors.at(node.outputIndex.front());
MS_ASSERT(outTensor != nullptr);
auto outQuantParam = GetTensorQuantParam(outTensor);
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0));
MS_ASSERT(tensor0 != nullptr);
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1));
auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1));
MS_ASSERT(tensor1 != nullptr);
if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
auto quantParam = GetTensorQuantParam(tensor1);
auto min = quantParam->min;
auto max = quantParam->max;
{
if (tensor1->dataType == TypeId::kNumberTypeFloat) {
MS_ASSERT(tensor1->data.size() == sizeof(float) / sizeof(uint8_t));
void *oriTensorData = tensor1->data.data();
auto *div = static_cast<float *>(oriTensorData);
MS_ASSERT(*div != 0);
status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div));
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
return RET_ERROR;
}
} else if (tensor1->dataType == TypeId::kNumberTypeUInt8) {
MS_ASSERT(tensor1->data.size() == 1);
void *oriTensorData = tensor1->data.data();
auto *div = static_cast<uint8_t *>(oriTensorData);
status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max + (*div));
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
return RET_ERROR;
}
} else {
// MS_LOGW("Unsupported tensor dataType: %d", tensor1->dataType);
return RET_ERROR;
}
}
} else {
// MS_LOGW("Can not determine realDiv outputTensor quantParam, node %s", node.name.c_str());
return RET_ERROR;
}
}
return RET_OK;
}
};
// Sets the output quant param of ops whose output range is known a priori
// (e.g. sigmoid -> [0, 1], softmax -> [0, 1]) to a fixed [min, max].
class CalcToSet : public QuantParamCalcer {
 public:
  // min/max: the fixed output value range for this op type.
  CalcToSet(float min, float max) : min(min), max(max) {}
  // Verifies all input quant params are known, then overwrites the output
  // tensor's first quant param with one derived from the fixed range.
  int Calc(MetaGraphT *graph, const CNodeT &node) override {
    MS_ASSERT(node.inputIndex.size() == 1);
    MS_ASSERT(node.outputIndex.size() == 1);
    auto status = QuantParamCalcer::Calc(graph, node);
    if (status != RET_OK) {
      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
      return status;
    }
    // input
    if (inputParamDone != node.inputIndex.size()) {
      // MS_LOGW("Can not determine inputTensor quantParam, node %s", node.name.c_str());
      return RET_ERROR;
    }
    // output
    std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT());
    if (quantParam == nullptr) {
      // MS_LOGW("new QuantParamT failed");
      return RET_ERROR;
    }
    // 256 quantization levels (uint8).
    quantParam->scale = (max - min) / 256;
    MS_ASSERT(quantParam->scale != 0);
    // With scale = (max - min) / 256, `256 - max/scale` equals `-min/scale`,
    // i.e. the standard zero point round(quantMin - min/scale) with quantMin = 0.
    quantParam->zeroPoint = int32_t(std::round(256 - max / quantParam->scale));
    quantParam->min = min;
    quantParam->max = max;
    quantParam->inited = true;
    MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
    auto &outTensor = graph->allTensors.at(node.outputIndex.front());
    MS_ASSERT(outTensor != nullptr);
    // NOTE(review): assumes quantParams is non-empty here — presumably the
    // base Calc() ensures a slot exists; confirm before relying on front().
    outTensor->quantParams.front() = std::move(quantParam);
    return RET_OK;
  }
 protected:
  float min;
  float max;
};
// Dispatches quant-param inference for Activation nodes: sigmoid outputs are
// fixed to [0, 1]; every other activation falls back to the common calcer.
class CalcActivation : public QuantParamCalcer {
 public:
  CalcActivation() = default;
  int Calc(MetaGraphT *subGraph, const CNodeT &node) override {
    MS_ASSERT(node.inputIndex.size() == 1);
    MS_ASSERT(node.outputIndex.size() == 1);
    // BUG FIX: the assert previously checked node.attr.AsActivation(), but
    // the code below dereferences node.primitive->value.AsActivation();
    // guard the pointers that are actually used.
    MS_ASSERT(node.primitive != nullptr);
    MS_ASSERT(node.primitive->value.AsActivation() != nullptr);
    if (node.primitive->value.AsActivation()->type == schema::ActivationType_SIGMOID) {
      auto calcToSet = CalcToSet(0, 1);
      return calcToSet.Calc(subGraph, node);
    } else {
      auto calCommon = CommonCalcer();
      return calCommon.Calc(subGraph, node);
    }
  }
};
// Populates the op-type -> calcer table. Calcers are heap-allocated once and
// intentionally live for the lifetime of the singleton; several op types
// share the same calcer instance.
QuantParamCalcRegister::QuantParamCalcRegister() {
  bool hasError = false;
  auto baseCalcer = new (std::nothrow) QuantParamCalcer();
  if (baseCalcer == nullptr) {
    // MS_LOGW("new QuantParamCalcer failed");
    hasError = true;
  }
  auto commonCalcer = new (std::nothrow) CommonCalcer();
  if (commonCalcer == nullptr) {
    // MS_LOGW("new commonCalcer failed");
    hasError = true;
  }
  auto linearCalcer = new (std::nothrow) LinearCalcer();
  if (linearCalcer == nullptr) {
    // MS_LOGW("new linearCalcer failed");
    hasError = true;
  }
  if (!hasError) {
    _registerMap[schema::PrimitiveType_Concat] = new CalcConcat();
    _registerMap[schema::PrimitiveType_Activation] = new CalcActivation();
    _registerMap[schema::PrimitiveType_Add] = new CalcAdd();
    _registerMap[schema::PrimitiveType_Mul] = commonCalcer;
    _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
    _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;
    _registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
    _registerMap[schema::PrimitiveType_Resize] = linearCalcer;
    _registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
    _registerMap[schema::PrimitiveType_Shape] = linearCalcer;  // todo if shape influence postNode's output quantParam
    _registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1);
    _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
    _registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv();
    _registerMap[schema::PrimitiveType_Reduce] = commonCalcer;
    _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer;
    _registerMap[schema::PrimitiveType_Mean] = linearCalcer;
    _registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
    _registerMap[schema::PrimitiveType_MatMul] = commonCalcer;
    _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer;
    _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
    _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
    // todo
    // detection_postprocess op's quant param will not infer only fetch from preNode or postNode
    // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
    // if quantTransNode is inserted after detection_postprocess node, there will be some errors
    _registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer;
  } else {
    // BUG FIX: previously any calcer that WAS allocated leaked when another
    // allocation failed; release them all (delete nullptr is a no-op).
    delete baseCalcer;
    delete commonCalcer;
    delete linearCalcer;
  }
}
// Returns the process-wide registry. Meyers singleton: the function-local
// static is constructed lazily and thread-safely (C++11 and later).
QuantParamCalcRegister *QuantParamCalcRegister::GetInstance() {
  static QuantParamCalcRegister registerInstance;
  return &registerInstance;
}
// Looks up the calcer registered for the given op type.
// Returns nullptr when no calcer is registered for it.
QuantParamCalcer *QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) {
  const auto iter = _registerMap.find(opType);
  return iter == _registerMap.end() ? nullptr : iter->second;
}
} // namespace mindspore::lite
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CALC_QUANT_PARAM_H
#define CALC_QUANT_PARAM_H
#include <unordered_map>
#include <memory>
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
static constexpr int CONVLUTION_INPUT_NUM = 3;
// Base class for per-op quant-param inference. The base Calc() walks a node's
// input/output tensors and counts how many already carry an initialized quant
// param (inputParamDone/outputParamDone); subclasses fill in the rest.
class QuantParamCalcer {
 public:
  virtual ~QuantParamCalcer() = default;
  // Counts initialized input/output quant params for `node`; subclasses
  // override to also derive missing ones. Returns RET_OK on success.
  virtual int Calc(schema::MetaGraphT *graph, const schema::CNodeT &node);
 protected:
  // Derives a quant param from a const tensor's raw data.
  STATUS ComputeConstQuantParam(const schema::TensorT &tensor, schema::QuantParamT *quantParam);
 protected:
  // Number of this node's input/output tensors whose quant param is inited,
  // refreshed by each Calc() call.
  size_t inputParamDone = 0;
  size_t outputParamDone = 0;
};
// Calcer for ops (Conv2D, Mul, MatMul, ...) whose quant params must all be
// present already; it validates rather than derives them.
class CommonCalcer : public QuantParamCalcer {
 public:
  CommonCalcer() = default;
  ~CommonCalcer() override = default;
  int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override;
};
// Calcer for shape-preserving / value-preserving ops (Reshape, Transpose,
// Pooling, ...) where input and output share the same quant param, so a
// missing side can be copied from the other.
class LinearCalcer : public QuantParamCalcer {
 public:
  LinearCalcer() = default;
  ~LinearCalcer() override = default;
  int Calc(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
// Process-wide registry mapping op type -> quant-param calcer.
// Access via GetInstance(); calcer instances are owned by the registry for
// the lifetime of the process.
class QuantParamCalcRegister {
 public:
  virtual ~QuantParamCalcRegister() = default;
  // Returns the calcer for `opType`, or nullptr if none is registered.
  QuantParamCalcer *GetQuantParamCalcer(schema::PrimitiveType opType);
  static QuantParamCalcRegister *GetInstance();
 private:
  QuantParamCalcRegister();
  std::unordered_map<schema::PrimitiveType, QuantParamCalcer *> _registerMap;
};
} // namespace lite
} // namespace mindspore
#endif
......@@ -39,126 +39,127 @@ QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThr
: mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}
// Decides whether a conv-like node is eligible for weight quantization.
// Requirements: the node name starts with one of mConvTypes, it has a weight
// input (input 2) that is a Parameter with a known shape, the weight element
// count reaches mWeightSize, and the output-channel count exceeds
// mConvWeightQuantChannelThreshold.
// NOTE(review): this span was a diff-merge artifact with every statement
// duplicated; deduplicated into a single coherent function.
bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
  size_t i = 0;
  for (i = 0; i < mConvTypes.size(); i++) {
    if (node->fullname_with_scope().find(mConvTypes[i]) == 0) {
      break;
    }
  }
  // Not a recognized conv op, or missing the weight input.
  if ((i == mConvTypes.size()) || (node->size() < 3)) {
    return false;
  }
  auto inputNode = node->input(2);
  if (!inputNode->isa<Parameter>()) {
    return false;
  }
  auto paramNode = inputNode->cast<ParameterPtr>();
  auto abstract_base = paramNode->abstract();
  if (abstract_base == nullptr) {
    return false;
  }
  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
    return false;
  }
  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  // Total number of weight elements.
  size_t shapeSize = 1;
  for (auto dim : weight_shape) {
    shapeSize = shapeSize * dim;
  }
  if (shapeSize < mWeightSize) {
    MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
    return false;
  }
  // weight_shape[0] is assumed to be the output-channel dimension — TODO
  // confirm the weight layout used by the converter.
  if (weight_shape[0] <= mConvWeightQuantChannelThreshold) {
    MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0];
    return false;
  }
  return true;
}
// Decides whether a node may take part in post-training quantization: it must
// be a CNode carrying a PrimitiveTValue whose primitive type is in the
// supported uint8 op list.
// NOTE(review): this span was a diff-merge artifact with every statement
// duplicated; deduplicated into a single coherent function.
bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = std::dynamic_pointer_cast<CNode>(node);
  auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
  if (primitiveT_value == nullptr) {
    MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
    return false;
  }
  auto type = primitiveT_value->GetPrimitiveT()->value.type;
  MS_LOG(INFO) << "Primitive type: " << type;
  // Op types currently supported by the uint8 post-training path.
  static const std::vector<schema::PrimitiveType> uint8OpList = {
      schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
      schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
      schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
      schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape,
      schema::PrimitiveType_Activation};
  return IsContain(uint8OpList, type);
}
bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const {
size_t i = 0;
for (i = 0; i < mMulTypes.size(); i++) {
if (node->fullname_with_scope().find(mMulTypes[i]) == 0) {
break;
}
}
if (i == mMulTypes.size()) {
return false;
size_t i = 0;
for (i = 0; i < mMulTypes.size(); i++) {
if (node->fullname_with_scope().find(mMulTypes[i]) == 0) {
break;
}
}
if (i == mMulTypes.size()) {
return false;
}
if (node->size() < 3) {
MS_LOG(INFO) << "input size less!";
return false;
}
if (node->size() < 3) {
MS_LOG(INFO) << "input size less!";
return false;
}
auto inputNode1 = node->input(1);
auto inputNode2 = node->input(2);
if (inputNode1 == nullptr || inputNode2 == nullptr) {
MS_LOG(INFO) << "mul input is nullptr!";
return false;
}
auto inputNode1 = node->input(1);
auto inputNode2 = node->input(2);
if (inputNode1 == nullptr || inputNode2 == nullptr) {
MS_LOG(INFO) << "mul input is nullptr!";
return false;
}
ParameterPtr paramNode = nullptr;
if (inputNode1->isa<Parameter>()) {
paramNode = inputNode1->cast<ParameterPtr>();
} else if (inputNode2->isa<Parameter>()) {
paramNode = inputNode2->cast<ParameterPtr>();
}
ParameterPtr paramNode = nullptr;
if (inputNode1->isa<Parameter>()) {
paramNode = inputNode1->cast<ParameterPtr>();
} else if (inputNode2->isa<Parameter>()) {
paramNode = inputNode2->cast<ParameterPtr>();
}
if (paramNode == nullptr) {
MS_LOG(INFO) << "invalid paramNode!";
return false;
}
if (paramNode == nullptr) {
MS_LOG(INFO) << "invalid paramNode!";
return false;
}
auto abstract_base = paramNode->abstract();
if (abstract_base == nullptr) {
MS_LOG(INFO) << "abstract is nullptr";
return false;
}
auto abstract_base = paramNode->abstract();
if (abstract_base == nullptr) {
MS_LOG(INFO) << "abstract is nullptr";
return false;
}
if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
size_t shapeSize = 1;
for (auto dim : weight_shape) {
shapeSize = shapeSize * dim;
}
if (shapeSize < mWeightSize) {
MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
return false;
}
if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name();
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
size_t shapeSize = 1;
for (auto dim : weight_shape) {
shapeSize = shapeSize * dim;
}
if (shapeSize < mWeightSize) {
MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize;
return false;
}
return true;
return true;
}
void CalFakeNode(const AnfNodePtr &inTensor) {
......@@ -190,56 +191,119 @@ void CalFakeNode(const AnfNodePtr &inTensor) {
// }
}
// Computes scale/zero-point for the range [mMin, mMax] mapped onto
// [quant_min, quant_max] and stores the result in `quantParam`.
// The range is first clipped so that it always contains 0 (a quantization
// requirement); a degenerate range (min == max == 0) yields scale 0 and
// zero-point 0. Returns RET_OK, RET_PARAM_INVALID, or RET_ERROR.
// NOTE(review): this span was a diff-merge artifact with the degenerate and
// general paths interleaved; reconstructed into one coherent function.
STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax, bool narrowRange,
                             int quant_max, int quant_min, int num_bits) {
  MS_ASSERT(quantParam != nullptr);
  // The representable range must include 0; clip and warn instead of failing.
  if (mMin > 0.0f) {
    MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
    mMin = 0.0f;
  }
  if (mMax < 0.0f) {
    MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
    mMax = 0.0f;
  }
  if (mMin > mMax) {
    MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
    return RET_PARAM_INVALID;
  }
  if (mMin == mMax) {
    if (mMin != 0.0f) {
      MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
      return RET_ERROR;
    }
    // Degenerate all-zero range: mark inited with zero scale.
    quantParam->inited = true;
    quantParam->min = mMin;
    quantParam->max = mMax;
    quantParam->scale = 0.0f;
    quantParam->zeroPoint = 0;
    quantParam->narrowRange = narrowRange;
    quantParam->numBits = num_bits;
    return RET_OK;
  }
  auto quantMinFloat = static_cast<double>(quant_min);
  auto quantMaxFloat = static_cast<double>(quant_max);
  double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
  const double zeroPointFromMin = quantMinFloat - mMin / scale;
  // const double zeroPointFromMax = quantMaxFloat - mMax / scale;
  int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));
  // The zero point should always be in the range of quantized value,
  // [qmin, qmax].
  // BUG FIX: these asserts referenced undeclared identifiers quantMin /
  // quantMax; the parameters are quant_min / quant_max.
  MS_ASSERT(zeroPoint >= quant_min);
  MS_ASSERT(zeroPoint <= quant_max);
  quantParam->inited = true;
  quantParam->min = mMin;
  quantParam->max = mMax;
  quantParam->scale = scale;
  quantParam->zeroPoint = zeroPoint;
  quantParam->narrowRange = narrowRange;
  quantParam->numBits = num_bits;
  return RET_OK;
}
// Computes scale/zero-point for the range [mMin, mMax] at `numBits`
// (narrowRange excludes quantized value 0) and stores the result in
// `quantParam`. The range is first clipped to contain 0; a degenerate range
// (min == max == 0) yields scale 0 and zero-point 0.
// Returns RET_OK, RET_PARAM_INVALID, or RET_ERROR.
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
                             bool narrowRange, int numBits) {
  MS_ASSERT(quantParam != nullptr);
  // The representable range must include 0; clip and warn instead of failing.
  if (mMin > 0.0f) {
    MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
    mMin = 0.0f;
  }
  if (mMax < 0.0f) {
    MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
    mMax = 0.0f;
  }
  if (mMin > mMax) {
    MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
    return RET_PARAM_INVALID;
  }
  if (mMin == mMax) {
    if (mMin != 0.0f) {
      MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
      return RET_ERROR;
    }
    // Degenerate all-zero range: mark inited with zero scale.
    quantParam->inited = true;
    quantParam->min = mMin;
    quantParam->max = mMax;
    quantParam->scale = 0.0f;
    quantParam->zeroPoint = 0;
    quantParam->narrowRange = narrowRange;
    quantParam->numBits = numBits;
    return RET_OK;
  }
  int quantMin = narrowRange ? 1 : 0;
  int quantMax = (1 << (unsigned int)numBits) - 1;
  auto quantMinFloat = static_cast<double>(quantMin);
  auto quantMaxFloat = static_cast<double>(quantMax);
  double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
  // Nudge the zero point: pick the candidate (from min or from max) with the
  // smaller rounding error, then clamp it into [quantMin, quantMax].
  const double zeroPointFromMin = quantMinFloat - mMin / scale;
  const double zeroPointFromMax = quantMaxFloat - mMax / scale;
  const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
  const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
  const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
  int zeroPoint;
  if (zpDouble < quantMinFloat) {
    zeroPoint = quantMin;
  } else if (zpDouble > quantMaxFloat) {
    zeroPoint = quantMax;
  } else {
    zeroPoint = static_cast<int32_t>(std::round(zpDouble));
  }
  // The zero point should always be in the range of quantized value,
  // [qmin, qmax].
  MS_ASSERT(zeroPoint >= quantMin);
  MS_ASSERT(zeroPoint <= quantMax);
  quantParam->inited = true;
  quantParam->min = mMin;
  quantParam->max = mMax;
  quantParam->scale = scale;
  quantParam->zeroPoint = zeroPoint;
  quantParam->narrowRange = narrowRange;
  quantParam->numBits = numBits;
  return RET_OK;
}
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum,
......@@ -292,14 +356,14 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
weightPtr->set_quant_param(quantParam);
}
auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(int8_t));
auto ret =
memcpy_s(const_cast<float *>(rawDatas), weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
if (quantType == QuantType_WeightQuant) {
PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum);
PostBitPack(const_cast<float *>(rawDatas), shapeSize, bitNum);
}
weightPtr->set_tensor_type(kNumberTypeInt8);
......@@ -338,14 +402,13 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
qDatas[i] = quant_max;
} else if (quant_data < quant_min) {
qDatas[i] = quant_min;
} else {
} else {
qDatas[i] = static_cast<int8_t>(quant_data);
}
}
weightPtr->set_quant_param(quantParam);
auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(int8_t));
auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
......@@ -358,34 +421,32 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
}
return RET_OK;
return RET_OK;
}
// Packs `shapeSize` quantized bytes (stored in the buffer behind `weight`)
// down to `bitNum` bits per value, in place. bitNum in (1, 8) packs via
// BitPack; bitNum == 8 copies through unchanged; anything else is an error.
// NOTE(review): this span was a diff-merge artifact with every statement
// duplicated; deduplicated into a single coherent function.
STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
  auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
  vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);
  vector<uint8_t> qDatas_packed;
  if (bitNum < 8 && bitNum > 1) {
    BitPack weight_bitpack(bitNum);
    weight_bitpack.BitPacking(qDatas, qDatas_packed);
    if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) {
      MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed";
      return RET_ERROR;
    }
  } else if (bitNum == 8) {
    // 8-bit needs no packing; write the bytes back as-is.
    if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) {
      MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed";
      return RET_ERROR;
    }
  } else {
    MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum;
    return RET_ERROR;
  }
  return RET_OK;
}
} // namespace quant
} // namespace lite
} // namespace mindspore
......@@ -62,6 +62,41 @@ class QuantStrategy {
STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax,
bool narrowRange, int quant_max, int quant_min, int num_bits);
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
bool narrowRange = false, int numBits = UINT8_QUANTIZATION);
// Quantizes a single float to integer type T using an initialized quant
// param: clamp to the representable real range, then round to
// originData / scale + zeroPoint. With narrowRange, quantized value 0 is
// reserved, so a 0 result is bumped to 1.
template <typename T>
T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
  MS_ASSERT(quantParam != nullptr);
  MS_ASSERT(quantParam->inited);
  const auto scale = quantParam->scale;
  const auto zeroPoint = quantParam->zeroPoint;
  const auto numBit = quantParam->numBits;
  const auto narrowRange = quantParam->narrowRange;
  // Largest / smallest real values representable by this quant param.
  const double upperBound = static_cast<float>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale;
  const double lowerBound =
    narrowRange ? static_cast<float>(1 - zeroPoint) * scale : static_cast<float>(0 - zeroPoint) * scale;
  // Clamp into the representable range before quantizing.
  double clamped = originData;
  if (clamped > upperBound) {
    clamped = upperBound;
  } else if (clamped < lowerBound) {
    clamped = lowerBound;
  }
  auto quantized = static_cast<T>(std::round(clamped / scale + zeroPoint));
  if (quantized == 0 && narrowRange) {
    quantized++;
  }
  return quantized;
}
template <typename T>
T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) {
MS_ASSERT(quantParam != nullptr);
......
......@@ -15,22 +15,19 @@
*/
#include "mindspore/lite/tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
namespace quant {
Quantizer::Quantizer(FuncGraphPtr graph) : funcGraph(graph) {
if (funcGraph == nullptr) {
return;
}
}
namespace mindspore::lite::quant {
// Default no-op implementations of the optional quantization passes;
// concrete anf-graph quantizers override only the passes they need.
STATUS Quantizer::GenerateQuantParam() { return RET_OK; }
STATUS Quantizer::RemoveFakeQuant() { return RET_OK; }
STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; }
} // namespace quant
} // namespace lite
} // namespace mindspore
// Default no-op implementations of the optional quantization passes;
// concrete flatbuffer-graph quantizers override only the passes they need.
STATUS FbQuantizer::GenerateQuantParam() { return RET_OK; }
STATUS FbQuantizer::RemoveFakeQuant() { return RET_OK; }
STATUS FbQuantizer::DetermineNodeQuantType() { return RET_OK; }
} // namespace mindspore::lite::quant
......@@ -18,48 +18,63 @@
#define MS_QUANTIZER_H
#include <unordered_map>
#include <utility>
#include <memory>
#include "include/errorcode.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "include/model.h"
#include "base/base.h"
#include "src/param_value_lite.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
namespace mindspore {
namespace lite {
namespace quant {
namespace mindspore::lite::quant {
using STATUS = int;
// Quantization mode of the converter. Values presumably mirror the
// schema-generated QuantType — confirm against model_generated.h.
// NOTE(review): this span was a diff-merge artifact with every enumerator
// duplicated (would not compile); deduplicated.
enum QuantType {
  QuantType_QUANT_NONE = 0,
  QuantType_AwareTraining = 1,
  QuantType_WeightQuant = 2,
  QuantType_PostTraining = 3,
  QuantType_MIN = QuantType_QUANT_NONE,
  QuantType_MAX = QuantType_PostTraining
};
// Abstract base for quantizers operating on an anf FuncGraph. Subclasses
// implement DoQuantize and may override the optional passes (which default
// to no-ops in quantizer.cc).
// NOTE(review): this span was a diff-merge artifact with duplicated members;
// deduplicated, keeping the inline constructor variant.
class Quantizer {
 public:
  explicit Quantizer(FuncGraphPtr graph) : funcGraph(std::move(graph)) {}
  // BUG FIX: class has virtual members (incl. pure virtual DoQuantize) and is
  // deleted polymorphically, so the destructor must be virtual.
  virtual ~Quantizer() = default;
  virtual STATUS RemoveFakeQuant();
  virtual STATUS GenerateQuantParam();
  virtual STATUS DetermineNodeQuantType();
  virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0;
  // Converter command-line flags controlling the quantization behavior.
  mindspore::lite::converter::Flags flags;
 protected:
  FuncGraphPtr funcGraph = nullptr;
};
} // namespace quant
} // namespace lite
} // namespace mindspore
#endif
// Abstract base for quantizers operating directly on a flatbuffer MetaGraphT.
// Subclasses implement DoQuantize and may override the optional passes.
class FbQuantizer {
 public:
  // NOTE(review): wrapping the raw pointer in a shared_ptr takes ownership of
  // `graph` — callers must not delete it themselves; confirm this is the
  // intended contract at every call site.
  explicit FbQuantizer(schema::MetaGraphT *graph) : graph(graph) {}
  // BUG FIX: class has virtual members (incl. pure virtual DoQuantize), so
  // the destructor must be virtual for safe polymorphic deletion.
  virtual ~FbQuantizer() = default;
  virtual STATUS RemoveFakeQuant();
  virtual STATUS GenerateQuantParam();
  virtual STATUS DetermineNodeQuantType();
  virtual STATUS DoQuantize() = 0;
 protected:
  std::shared_ptr<schema::MetaGraphT> graph = nullptr;
};
} // namespace mindspore::lite::quant
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册