提交 406ce735 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4624 adjust MS model quant param

Merge pull request !4624 from yankai10/merge
...@@ -47,7 +47,15 @@ class PrimitiveTValue : public Value { ...@@ -47,7 +47,15 @@ class PrimitiveTValue : public Value {
} }
} }
void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {}
// Replaces the stored per-input quantization parameters with a copy of
// `input_quant_param` (outer vector: one entry per input; inner vector: the
// QuantParamT records attached to that input).
void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
  input_quant_param_ = input_quant_param;
}
// Replaces the stored per-output quantization parameters with a copy of
// `output_quant_param` (outer vector: one entry per output; inner vector: the
// QuantParamT records attached to that output).
void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
  output_quant_param_ = output_quant_param;
}
void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) { void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
this->input_quant_param_.emplace_back(quant_param); this->input_quant_param_.emplace_back(quant_param);
......
...@@ -37,8 +37,13 @@ int Nchw2NhwcCPUKernel::Run() { ...@@ -37,8 +37,13 @@ int Nchw2NhwcCPUKernel::Run() {
auto output = out_tensors_[0]; auto output = out_tensors_[0];
if (input->shape().size() == 4) { if (input->shape().size() == 4) {
PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), if (input->data_type() == kNumberTypeFloat32) {
output->Channel()); PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
} else if (input->data_type() == kNumberTypeInt8) {
PackNCHWToNHWCInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
}
} else { } else {
memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float)); memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float));
} }
...@@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor ...@@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::tensor
} }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel
...@@ -37,8 +37,13 @@ int Nhwc2NchwCPUKernel::Run() { ...@@ -37,8 +37,13 @@ int Nhwc2NchwCPUKernel::Run() {
auto output = out_tensors_[0]; auto output = out_tensors_[0];
if (input->shape().size() == 4) { if (input->shape().size() == 4) {
PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), if (input->data_type() == kNumberTypeFloat32) {
output->Channel()); PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
} else if (input->data_type() == kNumberTypeInt8) {
PackNHWCToNCHWInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(),
output->Channel());
}
} else { } else {
memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float)); memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float));
} }
...@@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor ...@@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::tensor
} }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel
...@@ -978,6 +978,19 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int ...@@ -978,6 +978,19 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
return; return;
} }
// Repacks an int8 tensor from NHWC layout into NCHW layout.
//
//   src     - read-only buffer of batch * plane * channel int8 values, NHWC order.
//   dst     - output buffer of the same element count, filled in NCHW order.
//   batch   - number of images (N).
//   plane   - number of spatial positions per image (H * W).
//   channel - number of channels (C).
//
// The copy is done element by element out of place, so src and dst are
// expected to be distinct buffers. Nothing is written when any dimension is
// not positive.
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) {
  if (batch <= 0 || plane <= 0 || channel <= 0) {
    return;
  }
  // Keep the source const-qualified instead of casting the qualifier away.
  const int8_t *nhwc = static_cast<const int8_t *>(src);
  int8_t *nchw = static_cast<int8_t *>(dst);
  // Offsets are formed in 64 bits so batch * plane * channel cannot overflow int.
  const int64_t batch_stride = static_cast<int64_t>(plane) * channel;
  for (int n = 0; n < batch; ++n) {
    const int8_t *in_batch = nhwc + n * batch_stride;
    int8_t *out_batch = nchw + n * batch_stride;
    for (int c = 0; c < channel; ++c) {
      // In NCHW each channel owns one contiguous run of `plane` values.
      int8_t *out_channel = out_batch + static_cast<int64_t>(c) * plane;
      for (int hw = 0; hw < plane; ++hw) {
        out_channel[hw] = in_batch[static_cast<int64_t>(hw) * channel + c];
      }
    }
  }
}
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
return PackNHWCToNCHWFp32(src, dst, batch, channel, plane); return PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
} }
......
...@@ -60,6 +60,8 @@ void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int c ...@@ -60,6 +60,8 @@ void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int c
void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel);
......
...@@ -122,8 +122,10 @@ void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, f ...@@ -122,8 +122,10 @@ void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, f
*mMax = static_cast<float>((qmax - mean) / stdDev); *mMax = static_cast<float>((qmax - mean) / stdDev);
} }
void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, void AnfConvPopulater::PopulaterQuantParam(
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range"); auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range); bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits"); auto num_bits = prim->GetAttr("num_bits");
...@@ -154,7 +156,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -154,7 +156,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants); vecInputQuantParam->emplace_back(quants);
quants.clear(); quants.clear();
int biasQuantSize = 0; int biasQuantSize = 0;
...@@ -173,7 +175,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -173,7 +175,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
} }
vecQuantParam->emplace_back(quants); vecInputQuantParam->emplace_back(quants);
} }
quants.clear(); quants.clear();
...@@ -181,10 +183,12 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -181,10 +183,12 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quantParam.min = 0.0; quantParam.min = 0.0;
quantParam.max = 0.0; quantParam.max = 0.0;
quantParam.zeroPoint = 0; quantParam.zeroPoint = 0;
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
quantParam.scale =
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
} }
vecQuantParam->emplace_back(quants); vecInputQuantParam->emplace_back(quants);
quants.clear(); quants.clear();
auto outputMin = prim->GetAttr("output_minq"); auto outputMin = prim->GetAttr("output_minq");
...@@ -199,7 +203,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -199,7 +203,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants); vecOutputQuantParam->emplace_back(quants);
} }
} }
...@@ -215,10 +219,13 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit ...@@ -215,10 +219,13 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
PopulaterConv2DSingleGroup(prim, primitive, group); PopulaterConv2DSingleGroup(prim, primitive, group);
} }
primitiveTValuePtr->SetPrimitiveT(primitive.release()); primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam; std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
PopulaterQuantParam(prim, &vecQuantParam); std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
primitiveTValuePtr->SetInputQuantParam(vecQuantParam); PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
} }
return 0; return 0;
} }
......
...@@ -20,9 +20,10 @@ ...@@ -20,9 +20,10 @@
#ifndef MINDSPORE_ANF_CONV_PARSER_H #ifndef MINDSPORE_ANF_CONV_PARSER_H
#define MINDSPORE_ANF_CONV_PARSER_H #define MINDSPORE_ANF_CONV_PARSER_H
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
#include <memory> #include <memory>
#include <vector>
#include "tools/anf_importer/anf_populater/anf_node_populater.h"
namespace mindspore::lite { namespace mindspore::lite {
class AnfConvPopulater : public AnfNodePopulater { class AnfConvPopulater : public AnfNodePopulater {
public: public:
...@@ -32,12 +33,18 @@ class AnfConvPopulater : public AnfNodePopulater { ...@@ -32,12 +33,18 @@ class AnfConvPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override; const std::vector<AnfNodePtr> &inputs) override;
private: private:
void PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive, void PopulaterConv2DMultiGroup(
const int &group); const PrimitivePtr &prim,
void PopulaterConv2DSingleGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive, const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
const int &group); void PopulaterConv2DSingleGroup(
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); const PrimitivePtr &prim,
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group);
void PopulaterQuantParam(
const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
}; };
} // namespace mindspore::lite } // namespace mindspore::lite
......
...@@ -31,8 +31,10 @@ void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double & ...@@ -31,8 +31,10 @@ void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double &
*mMax = static_cast<float>((qmax - mean) / stdDev); *mMax = static_cast<float>((qmax - mean) / stdDev);
} }
void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, void AnfDepwiseconv2DPopulater::PopulaterQuantParam(
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range"); auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range); bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits"); auto num_bits = prim->GetAttr("num_bits");
...@@ -63,7 +65,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -63,7 +65,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants); vecInputQuantParam->emplace_back(quants);
quants.clear(); quants.clear();
int biasQuantSize = 0; int biasQuantSize = 0;
...@@ -82,7 +84,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -82,7 +84,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
} }
vecQuantParam->emplace_back(quants); vecInputQuantParam->emplace_back(quants);
} }
quants.clear(); quants.clear();
...@@ -90,10 +92,12 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -90,10 +92,12 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quantParam.min = 0.0; quantParam.min = 0.0;
quantParam.max = 0.0; quantParam.max = 0.0;
quantParam.zeroPoint = 0; quantParam.zeroPoint = 0;
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
quantParam.scale =
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
} }
vecQuantParam->emplace_back(quants); vecInputQuantParam->emplace_back(quants);
quants.clear(); quants.clear();
auto outputMin = prim->GetAttr("output_minq"); auto outputMin = prim->GetAttr("output_minq");
...@@ -108,7 +112,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -108,7 +112,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants); vecOutputQuantParam->emplace_back(quants);
} }
} }
...@@ -177,10 +181,12 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu ...@@ -177,10 +181,12 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu
MS_ASSERT(primitiveTValuePtr != nullptr); MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release()); primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType()) { if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam; std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
PopulaterQuantParam(prim, &vecQuantParam); std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
primitiveTValuePtr->SetInputQuantParam(vecQuantParam); PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
} }
return 0; return 0;
} }
......
...@@ -28,8 +28,12 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater { ...@@ -28,8 +28,12 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override; const std::vector<AnfNodePtr> &inputs) override;
private: private:
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); void PopulaterQuantParam(
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
}; };
} // namespace mindspore::lite } // namespace mindspore::lite
......
...@@ -30,8 +30,10 @@ void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, ...@@ -30,8 +30,10 @@ void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev,
*mMax = static_cast<float>((qmax - mean) / stdDev); *mMax = static_cast<float>((qmax - mean) / stdDev);
} }
void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, void AnfMatmulPopulater::PopulaterQuantParam(
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
auto narrow_range = prim->GetAttr("narrow_range"); auto narrow_range = prim->GetAttr("narrow_range");
bool narrowRangeQuantParam = GetValue<bool>(narrow_range); bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
auto num_bits = prim->GetAttr("num_bits"); auto num_bits = prim->GetAttr("num_bits");
...@@ -62,7 +64,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -62,7 +64,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants); vecInputQuantParam->emplace_back(quants);
quants.clear(); quants.clear();
auto filterMin = prim->GetAttr("filter_minq"); auto filterMin = prim->GetAttr("filter_minq");
...@@ -79,7 +81,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -79,7 +81,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
} }
vecQuantParam->emplace_back(quants); vecInputQuantParam->emplace_back(quants);
} }
quants.clear(); quants.clear();
...@@ -95,7 +97,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, ...@@ -95,7 +97,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
numbitsRangeQuantParam); numbitsRangeQuantParam);
quants.emplace_back(quantParam); quants.emplace_back(quantParam);
vecQuantParam->emplace_back(quants); vecOutputQuantParam->emplace_back(quants);
} }
} }
...@@ -110,12 +112,13 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim ...@@ -110,12 +112,13 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim
primitive->value.value = attr.release(); primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr); MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release()); primitiveTValuePtr->SetPrimitiveT(primitive.release());
if (primitiveTValuePtr->GetQuantType()) { if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
std::vector<std::vector<schema::QuantParamT>> vecQuantParam; std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
PopulaterQuantParam(prim, &vecQuantParam); std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
primitiveTValuePtr->SetInputQuantParam(vecQuantParam); PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
} }
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());
......
...@@ -26,8 +26,12 @@ class AnfMatmulPopulater : public AnfNodePopulater { ...@@ -26,8 +26,12 @@ class AnfMatmulPopulater : public AnfNodePopulater {
const std::vector<AnfNodePtr> &inputs) override; const std::vector<AnfNodePtr> &inputs) override;
private: private:
void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); void PopulaterQuantParam(
void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); const PrimitivePtr &prim,
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam);
void CalQuantParam(const double &mean, const double &stdDev, float *mMin,
float *mMax);
}; };
} // namespace mindspore::lite } // namespace mindspore::lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册