提交 406ce735 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4624 adjust MS model quant param

Merge pull request !4624 from yankai10/merge
......@@ -47,7 +47,15 @@ class PrimitiveTValue : public Value {
}
}
void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {}
// Overwrites the stored input quantization parameters with a copy of
// `input_quant_param` (one inner vector of QuantParamT per entry).
void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
  input_quant_param_ = input_quant_param;
}
// Overwrites the stored output quantization parameters with a copy of
// `output_quant_param` (one inner vector of QuantParamT per entry).
void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
  output_quant_param_ = output_quant_param;
}
void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->input_quant_param_.emplace_back(quant_param);
......
......@@ -37,8 +37,13 @@ int Nchw2NhwcCPUKernel::Run() {
auto output = out_tensors_[0];
if (input->shape().size() == 4) {
PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
if (input->data_type() == kNumberTypeFloat32) {
PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
} else if (input->data_type() == kNumberTypeInt8) {
PackNCHWToNHWCInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
}
} else {
memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float));
}
......@@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor
}
// Register the Nchw2Nhwc creator for both float32 and int8 tensors. The creator keeps its
// "Fp32" name for both; Nchw2NhwcCPUKernel::Run() chooses the Fp32 or Int8 packing routine
// from the input tensor's data_type().
// NOTE(review): the non-4D path in Run() copies ElementsNum() * sizeof(float) bytes regardless
// of data type -- confirm this is correct for int8 tensors.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
} // namespace mindspore::kernel
......@@ -37,8 +37,13 @@ int Nhwc2NchwCPUKernel::Run() {
auto output = out_tensors_[0];
if (input->shape().size() == 4) {
PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
if (input->data_type() == kNumberTypeFloat32) {
PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
} else if (input->data_type() == kNumberTypeInt8) {
PackNHWCToNCHWInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
}
} else {
memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float));
}
......@@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor
}
// Register the Nhwc2Nchw creator for both float32 and int8 tensors. The creator keeps its
// "Fp32" name for both; Nhwc2NchwCPUKernel::Run() chooses the Fp32 or Int8 packing routine
// from the input tensor's data_type().
// NOTE(review): the non-4D path in Run() copies ElementsNum() * sizeof(float) bytes regardless
// of data type -- confirm this is correct for int8 tensors.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
} // namespace mindspore::kernel
......@@ -978,6 +978,19 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
return;
}
// Repacks int8 tensor data from NHWC layout into NCHW layout.
//
//   src     - input buffer, read as batch * plane * channel int8 values in NHWC order
//   dst     - output buffer of the same element count, written in NCHW order
//   batch   - number of batches (N)
//   plane   - number of spatial positions per batch (H * W)
//   channel - number of channels (C)
//
// For every batch n: dst[n][c][hw] = src[n][hw][c]. Nothing is written when any
// of batch, plane or channel is <= 0. src is only read, never modified.
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) {
  // Cast once, and keep src const-qualified instead of casting constness away per element.
  const int8_t *src_data = (const int8_t *)src;
  int8_t *dst_data = (int8_t *)dst;
  for (int n = 0; n < batch; n++) {
    // Both layouts hold channel * plane elements per batch, so they share the batch offset.
    int batch_offset = n * channel * plane;
    for (int c = 0; c < channel; c++) {
      for (int hw = 0; hw < plane; hw++) {
        int nhwc_index = batch_offset + hw * channel + c;
        int nchw_index = batch_offset + c * plane + hw;
        dst_data[nchw_index] = src_data[nhwc_index];
      }
    }
  }
}
// Repacks float32 tensor data from NCHW layout into NHWC layout.
//
// Implemented by forwarding to PackNHWCToNCHWFp32 with the plane and channel
// arguments swapped. This relies on PackNHWCToNCHWFp32 transposing a
// [batch][plane][channel] buffer into [batch][channel][plane]; with the two
// extents exchanged, the same transpose maps NCHW onto NHWC.
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  // Plain call: a void function should not `return` an expression.
  PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
}
......
......@@ -60,6 +60,8 @@ void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int c
// Float32 NHWC -> NCHW repack (per the name; verify details against the definition).
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
// Int8 NHWC -> NCHW repack: for each batch, src element (hw, c) is written to dst element (c, hw).
// `plane` is the number of spatial positions per batch.
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
// Float32 NCHW -> NHWC repack; forwards to PackNHWCToNCHWFp32 with plane and channel swapped.
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
// Presumably unpacks channel-padded NHWC4 float32 data back to plain NHWC -- TODO confirm
// against the definition.
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
......
......@@ -122,8 +122,10 @@ void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, f
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
void AnfConvPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
......@@ -154,7 +156,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
int biasQuantSize = 0;
......@@ -173,7 +175,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
......@@ -181,10 +183,12 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
quantParam.scale =
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim->GetAttr("output_minq");
......@@ -199,7 +203,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecOutputQuantParam->emplace_back(quants);
}
}
......@@ -215,10 +219,13 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
PopulaterConv2DSingleGroup(prim, primitive, group);
}
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
}
return 0;
}
......
......@@ -20,9 +20,10 @@
#ifndef MINDSPORE_ANF_CONV_PARSER_H
#define MINDSPORE_ANF_CONV_PARSER_H
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
#include <memory>
#include <vector>
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
namespace mindspore::lite {
class AnfConvPopulater : public AnfNodePopulater {
public:
......@@ -32,12 +33,18 @@ class AnfConvPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive,
const int &group);
void PopulaterConv2DSingleGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive,
const int &group);
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
void PopulaterConv2DMultiGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
void PopulaterConv2DSingleGroup(
const PrimitivePtr &prim,
const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
void PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite
......
......@@ -31,8 +31,10 @@ void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double &
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
......@@ -63,7 +65,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
int biasQuantSize = 0;
......@@ -82,7 +84,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
......@@ -90,10 +92,12 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quantParam.min = 0.0;
quantParam.max = 0.0;
quantParam.zeroPoint = 0;
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
quantParam.scale =
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto outputMin = prim->GetAttr("output_minq");
......@@ -108,7 +112,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecOutputQuantParam->emplace_back(quants);
}
}
......@@ -177,10 +181,12 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType()) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
}
return 0;
}
......
......@@ -28,8 +28,12 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
void PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite
......
......@@ -30,8 +30,10 @@ void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev,
*mMax = static_cast<float>((qmax - mean) / stdDev);
}
void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
void AnfMatmulPopulater::PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits");
......@@ -62,7 +64,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
quants.clear();
auto filterMin = prim->GetAttr("filter_minq");
......@@ -79,7 +81,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
}
vecQuantParam->emplace_back(quants);
vecInputQuantParam->emplace_back(quants);
}
quants.clear();
......@@ -95,7 +97,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam);
quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants);
vecOutputQuantParam->emplace_back(quants);
}
}
......@@ -110,12 +112,13 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType()) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
PopulaterQuantParam(prim, &vecQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
}
return 0;
}
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());
......
......@@ -26,8 +26,12 @@ class AnfMatmulPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override;
private:
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax);
void PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
};
} // namespace mindspore::lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册