提交 15564504 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4972 fix post training quant

Merge pull request !4972 from xutianchun/quant_0822
...@@ -227,10 +227,10 @@ struct DivergInfo { ...@@ -227,10 +227,10 @@ struct DivergInfo {
int zero_point = 0; int zero_point = 0;
if (quant_min == 0 && quant_max == 255) { if (quant_min == 0 && quant_max == 255) {
zero_point = 128; zero_point = 128;
} else if (quant_min == -128 && quant_max == 127) { } else if (quant_min == -127 && quant_max == 127) {
zero_point = 0; zero_point = 0;
} else { } else {
MS_LOG(ERROR) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max; MS_LOG(WARNING) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
} }
return std::make_pair(this->cnode, zero_point); return std::make_pair(this->cnode, zero_point);
} }
...@@ -486,7 +486,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in ...@@ -486,7 +486,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
this->target_type_ = target_type; this->target_type_ = target_type;
if (target_type == kNumberTypeInt8) { if (target_type == kNumberTypeInt8) {
quant_max = (1 << (this->bit_num - 1)) - 1; // 127 quant_max = (1 << (this->bit_num - 1)) - 1; // 127
quant_min = -(1 << (this->bit_num - 1)); // -128 quant_min = -quant_max; // -127
} else if (target_type == kNumberTypeUInt8) { } else if (target_type == kNumberTypeUInt8) {
quant_max = (1 << this->bit_num) - 1; // 255 quant_max = (1 << this->bit_num) - 1; // 255
quant_min = 0; quant_min = 0;
......
...@@ -100,7 +100,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { ...@@ -100,7 +100,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
} }
std::vector<AnfNodePtr> op_inputs = {value_node, input_cnode}; std::vector<AnfNodePtr> op_inputs = {value_node, input_cnode};
auto quant_cast_cnode = graph->NewCNode(op_inputs); auto quant_cast_cnode = graph->NewCNode(op_inputs);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + std::to_string(i));
cnode->set_input(i, quant_cast_cnode); cnode->set_input(i, quant_cast_cnode);
MS_LOG(DEBUG) << "Add quant cast. " MS_LOG(DEBUG) << "Add quant cast. "
<< "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
......
...@@ -220,11 +220,11 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl ...@@ -220,11 +220,11 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
bool narrowRange, int numBits) { bool narrowRange, int numBits) {
MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam != nullptr);
if (mMin > 0.0f) { if (mMin > 0.0f) {
MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; MS_LOG(DEBUG) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
mMin = 0.0f; mMin = 0.0f;
} }
if (mMax < 0.0f) { if (mMax < 0.0f) {
MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; MS_LOG(DEBUG) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
mMax = 0.0f; mMax = 0.0f;
} }
if (mMin > mMax) { if (mMin > mMax) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册