Commit da1c32a7, authored by cjh9368

fix bug for converter_flags

Parent commit: 2ef32167
......@@ -29,12 +29,12 @@ Flags::Flags() {
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
AddFlag(&Flags::weightFile, "weightFile",
"Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
AddFlag(&Flags::inferenceType, "inferenceType",
"Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT");
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", "");
AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT");
AddFlag(&Flags::inferenceTypeIn, "inferenceType",
"Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT");
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining", "");
AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT");
AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127");
AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
......@@ -77,14 +77,24 @@ int Flags::Init(int argc, const char **argv) {
}
// Map the textual --inputInferenceType flag value to a TypeId.
// (Diff view: the UINT8 branch below is the pre-change code; the INT8 branch
// and the two-line error message are its replacement.)
if (this->inputInferenceTypeIn == "FLOAT") {
this->inputInferenceType = TypeId::kNumberTypeFloat;
} else if (this->inputInferenceTypeIn == "UINT8") {
this->inputInferenceType = TypeId::kNumberTypeUInt8;
} else if (this->inputInferenceTypeIn == "INT8") {
this->inputInferenceType = TypeId::kNumberTypeInt8;
} else {
// NOTE(review): iostreams do not substitute printf-style "%s", and the comma
// operator discards the c_str() result, so the offending value is never
// printed. Prefer: std::cerr << "..." << this->inputInferenceTypeIn << "\n";
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8",
this->inputInferenceTypeIn.c_str();
return 1;
}
// Map --inferenceType (data type stored in the output file) the same way.
if (this->inferenceTypeIn == "FLOAT") {
this->inferenceType = TypeId::kNumberTypeFloat;
} else if (this->inferenceTypeIn == "INT8") {
this->inferenceType = TypeId::kNumberTypeInt8;
} else {
// NOTE(review): same "%s" / comma-operator issue as above — the invalid
// value is never actually printed.
std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8",
this->inferenceTypeIn.c_str();
return 1;
}
if (this->fmkIn == "CAFFE") {
this->fmk = FmkType_CAFFE;
} else if (this->fmkIn == "MS") {
......
......@@ -63,10 +63,10 @@ class Flags : public virtual mindspore::lite::FlagParser {
// used for quantization
std::string quantTypeIn;
QuantType quantType;
std::string inferenceType;
std::string inferenceTypeIn;
TypeId inferenceType = TypeId::kNumberTypeFloat;
// used for parse aware trainning
std::string inputInferenceTypeIn;
// mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT;
TypeId inputInferenceType = TypeId::kNumberTypeFloat;
std::string stdDev;
std::string mean;
......
......@@ -194,6 +194,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
return RET_ERROR;
}
dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
dTypeTransPass->SetOutputDataDType(ctx.inferenceType);
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
......
......@@ -101,7 +101,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
if (inputDataDType == TypeId::kNumberTypeInt8) {
if (outputDataDType == TypeId::kNumberTypeInt8) {
return RET_OK;
}
MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat);
......@@ -231,5 +231,8 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
}
// Records the data type the pass uses when transforming model-input tensors.
void DTypeTransPass::SetInputDataDType(TypeId dataType) {
  inputDataDType = dataType;
}
// Records the data type the pass uses when transforming model-output tensors.
void DTypeTransPass::SetOutputDataDType(TypeId dataType) {
  outputDataDType = dataType;
}
} // namespace lite
} // namespace mindspore
......@@ -38,6 +38,8 @@ class DTypeTransPass : public GraphPass {
void SetInputDataDType(TypeId dataType);
void SetOutputDataDType(TypeId dataType);
private:
STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
......@@ -51,6 +53,7 @@ class DTypeTransPass : public GraphPass {
private:
size_t id;
TypeId inputDataDType = TypeId::kNumberTypeFloat;
TypeId outputDataDType = TypeId::kNumberTypeFloat;
OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);
......
......@@ -88,7 +88,7 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
if (inputInferType == "FLOAT") {
inArr.reset(new (std::nothrow) InputArray(mean, stdValue));
} else {
inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeUInt8));
inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8));
}
mInputArray = inArr.get();
mInputArray->InitQuantParam();
......
......@@ -37,8 +37,8 @@ struct InputArray {
// Builds the float range [mMin, mMax] that corresponds to the full quantized
// range, given the aware-quantization mean/stdDev parameters.
//
// @param mean     Mean used for aware quantization (de-quantization offset).
// @param stdDev   Standard deviation (de-quantization scale divisor).
// @param dataType Tensor element type; defaults to float (non-quantized).
//
// Fix: the diff overlay left BOTH the old uint8 range ([0, 255]) and the new
// int8 range as duplicate constexpr definitions of qmin/qmax, which is a
// redefinition error. Keep only the post-change signed 8-bit range.
InputArray(float mean, float stdDev,
           TypeId dataType = TypeId::kNumberTypeFloat) {
  this->dataType = dataType;
  // Signed 8-bit quantized range (int8): [-128, 127].
  constexpr float qmin = -128;
  constexpr float qmax = 127;
  // De-quantize the range endpoints back to float space.
  mMin = (qmin - mean) / stdDev;
  mMax = (qmax - mean) / stdDev;
}
......
......@@ -246,8 +246,8 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
return RET_OK;
}
// (Pre-change) unsigned quantization bounds: [0, 2^numBits - 1], with the
// minimum raised to 1 when narrowRange is set.
int quantMin = narrowRange ? 1 : 0;
int quantMax = (1 << (unsigned int) numBits) - 1;
// NOTE(review): operator precedence bug — the conditional operator binds
// looser than '-', so "narrowRange ? 1 : 0 - 128" parses as
// (narrowRange ? 1 : -128). With narrowRange the minimum stays 1 instead of
// the intended -127. If the signed range [-127, 127] / [-128, 127] is
// intended, write "(narrowRange ? 1 : 0) - 128".
int quantMin = narrowRange ? 1 : 0 - 128;
int quantMax = (1 << (unsigned int) numBits) - 1 - 128;
auto quantMinFloat = static_cast<double>(quantMin);
auto quantMaxFloat = static_cast<double>(quantMax);
// scale maps the real range [mMin, mMax] onto the quantized range.
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.