提交 1fc91411 编写于 作者: X xutianchun

Post Training Quantization

上级 0df5a561
...@@ -177,31 +177,44 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { ...@@ -177,31 +177,44 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
if (node->quantType == schema::QuantType_PostTraining) { if (node->quantType == schema::QuantType_PostTraining) {
MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; MS_LOG(INFO) << "node: " << node->name << " add QuantParam";
// activation // activation
auto activate_index = node->inputIndex[0];
auto tensor_input = metaGraphT->allTensors[activate_index].get();
auto input_quant_params = primitiveT_value->GetInputQuantParams(); auto input_quant_params = primitiveT_value->GetInputQuantParams();
if (input_quant_params.empty()) { auto node_type = primitiveT_value->GetPrimitiveT()->value.type;
MS_LOG(WARNING) << "node: " << node->name for (int i = 0; i < input_quant_params.size(); i++) {
<< " input quant params is empty"; if (i >= node->inputIndex.size()) {
} else { MS_LOG(ERROR) << "node: " << node->name << " input has " << input_quant_params.size()
<< " quant_params; but only " << node->inputIndex.size() << " input";
break;
}
auto activate_index = node->inputIndex[i];
auto tensor_input = metaGraphT->allTensors[activate_index].get();
std::unique_ptr<schema::QuantParamT> input_quant_param = std::unique_ptr<schema::QuantParamT> input_quant_param =
std::make_unique<schema::QuantParamT>(input_quant_params[0]); std::make_unique<schema::QuantParamT>(input_quant_params[i]);
MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
<< " zp: " << input_quant_param->zeroPoint;
tensor_input->quantParams.emplace_back(std::move(input_quant_param)); tensor_input->quantParams.emplace_back(std::move(input_quant_param));
if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
tensor_input->dataType = kNumberTypeInt8;
}
} }
tensor_input->dataType = kNumberTypeInt8;
// output // output
auto output_index = node->outputIndex[0]; auto output_index = node->outputIndex[0];
auto tensor_output = metaGraphT->allTensors[output_index].get(); auto tensor_output = metaGraphT->allTensors[output_index].get();
auto output_quant_params = primitiveT_value->GetOutputQuantParams(); auto output_quant_params = primitiveT_value->GetOutputQuantParams();
if (output_quant_params.empty()) { if (output_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << node->name MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
<< " output quant params is empty";
} else { } else {
std::unique_ptr<schema::QuantParamT> output_quant_param = std::unique_ptr<schema::QuantParamT> output_quant_param =
std::make_unique<schema::QuantParamT>(output_quant_params[0]); std::make_unique<schema::QuantParamT>(output_quant_params[0]);
MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
<< " zp: " << output_quant_param->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param)); tensor_output->quantParams.emplace_back(std::move(output_quant_param));
} }
tensor_output->dataType = kNumberTypeInt8; if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) {
tensor_output->dataType = kNumberTypeInt8;
}
// // TensorType // // TensorType
// valuePtr = primitive->GetAttr(kInputTensorDataType); // valuePtr = primitive->GetAttr(kInputTensorDataType);
// if (valuePtr != nullptr) { // if (valuePtr != nullptr) {
......
...@@ -64,6 +64,16 @@ int LiteSession::ConvertTensors(const lite::Model *model) { ...@@ -64,6 +64,16 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
// no copy data, do copy when call LiteKernel::Init // no copy data, do copy when call LiteKernel::Init
dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data())); dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data()));
} }
auto quant_params = srcTensor->quantParams();
if (quant_params != nullptr) {
for (int j = 0; j < quant_params->size(); j++) {
tensor::QuantArg quant_arg{};
quant_arg.scale = quant_params->Get(j)->scale();
quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint();
dstTensor->AddQuantParam(quant_arg);
}
}
this->tensors.emplace_back(dstTensor); this->tensors.emplace_back(dstTensor);
} }
return RET_OK; return RET_OK;
......
...@@ -30,6 +30,7 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto ...@@ -30,6 +30,7 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
auto param = primitive->value_as_QuantDTypeCast(); auto param = primitive->value_as_QuantDTypeCast();
MS_ASSERT(input->data_type() == param->srcT); MS_ASSERT(input->data_type() == param->srcT);
output->set_data_type(static_cast<TypeId>(param->dstT())); output->set_data_type(static_cast<TypeId>(param->dstT()));
output->SetFormat(input->GetFormat());
return RET_OK; return RET_OK;
} }
} // namespace mindspore::lite } // namespace mindspore::lite
...@@ -58,7 +58,7 @@ int QuantDTypeCastCPUKernel::Init() { ...@@ -58,7 +58,7 @@ int QuantDTypeCastCPUKernel::Init() {
} }
inverse_ = true; inverse_ = true;
} else { } else {
MS_LOG(ERROR) << "param data type not supported."; MS_LOG(ERROR) << "param data type not supported:" << " src: " << param->srcT << " dst: " << param->dstT;
return RET_ERROR; return RET_ERROR;
} }
...@@ -143,7 +143,6 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::t ...@@ -143,7 +143,6 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::t
} }
return kernel; return kernel;
} }
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel
...@@ -23,7 +23,7 @@ int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_ ...@@ -23,7 +23,7 @@ int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_
} }
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
real_values[i] = (quant_values[i] + zp) * scale; real_values[i] = (quant_values[i] - zp) * scale;
} }
return NNACL_OK; return NNACL_OK;
} }
...@@ -34,7 +34,14 @@ int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_ ...@@ -34,7 +34,14 @@ int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_
} }
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
quant_values[i] = (int8_t)round(real_values[i] / scale + zp); float temp = round(real_values[i] / scale + zp);
if (temp > 127) {
quant_values[i] = 127;
} else if (temp < -128) {
quant_values[i] = -128;
} else {
quant_values[i] = (int8_t)temp;
}
} }
return NNACL_OK; return NNACL_OK;
} }
...@@ -166,6 +166,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { ...@@ -166,6 +166,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return -1; return -1;
} }
} }
MS_LOG(DEBUG) << "weight_tensor_format: " << weightTensor->format;
return 0; return 0;
} else if (fmkType == converter::FmkType_ONNX) { } else if (fmkType == converter::FmkType_ONNX) {
switch (node->quantType) { switch (node->quantType) {
...@@ -217,7 +218,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { ...@@ -217,7 +218,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
auto opType = node->primitive->value.type; auto opType = node->primitive->value.type;
if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D &&
opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) {
return 0; return RET_OK;
} }
MS_ASSERT(node->inputIndex.size() >= 2); MS_ASSERT(node->inputIndex.size() >= 2);
...@@ -225,7 +226,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { ...@@ -225,7 +226,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
MS_ASSERT(subGraph->allTensors.size() > weightIndex); MS_ASSERT(subGraph->allTensors.size() > weightIndex);
auto &weightTensor = subGraph->allTensors[weightIndex]; auto &weightTensor = subGraph->allTensors[weightIndex];
MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT
STATUS status; STATUS status = RET_OK;
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK
if (weightTensor->format == schema::Format_KCHW) { // from caffe if (weightTensor->format == schema::Format_KCHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
...@@ -238,11 +239,12 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { ...@@ -238,11 +239,12 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
} }
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx } else if (weightTensor->format == schema::Format_KHWC) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { return RET_OK;
status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); // if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
} else { // status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); // } else {
} // status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
} else if (weightTensor->format == schema::Format_HWCK) { // from tf } else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0; return 0;
} else { } else {
...@@ -273,8 +275,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { ...@@ -273,8 +275,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else if (weightTensor->format == schema::Format_HWCK) { // from tf } else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0; return 0;
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx } else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK); status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
} else { } else {
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK);
} }
......
...@@ -54,7 +54,7 @@ struct DivergInfo { ...@@ -54,7 +54,7 @@ struct DivergInfo {
size_t bit_num; size_t bit_num;
int quant_max = 255; int quant_max = 255;
int quant_min = 0; int quant_min = 0;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max = 255, int quant_min = 0) { DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) {
this->cnode = cnode; this->cnode = cnode;
this->bin_num = bins; this->bin_num = bins;
this->bit_num = bits; this->bit_num = bits;
...@@ -81,6 +81,9 @@ struct DivergInfo { ...@@ -81,6 +81,9 @@ struct DivergInfo {
STATUS UpdateHistogram(const std::vector<float> &data, const std::vector<int> &shape) { STATUS UpdateHistogram(const std::vector<float> &data, const std::vector<int> &shape) {
for (auto value : data) { for (auto value : data) {
if (value == 0) {
continue;
}
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1); int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
this->histogram[bin_index]++; this->histogram[bin_index]++;
} }
...@@ -470,8 +473,10 @@ STATUS Calibrator::ReadConfig() { ...@@ -470,8 +473,10 @@ STATUS Calibrator::ReadConfig() {
Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin) Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin)
: config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {} : config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {}
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type) PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type,
bool per_channel)
: Quantizer(graph) { : Quantizer(graph) {
this->per_channel_ = per_channel;
this->bit_num = bit_num; this->bit_num = bit_num;
this->target_type_ = target_type; this->target_type_ = target_type;
if (target_type == kNumberTypeInt8) { if (target_type == kNumberTypeInt8) {
...@@ -533,7 +538,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) { ...@@ -533,7 +538,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) {
} }
auto parameter = std::dynamic_pointer_cast<Parameter>(node); auto parameter = std::dynamic_pointer_cast<Parameter>(node);
ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param()); ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num); auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num, per_channel_);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status; MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status; return status;
...@@ -670,18 +675,32 @@ STATUS PostTrainingQuantizer::QuantNode() { ...@@ -670,18 +675,32 @@ STATUS PostTrainingQuantizer::QuantNode() {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
continue; continue;
} }
if (input_scale.find(cnode) == input_scale.end()) { if (input_scale.find(cnode) == input_scale.end()) {
primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE);
continue; continue;
} }
auto input_vec = cnode->inputs(); auto input_vec = cnode->inputs();
auto op_name = cnode->fullname_with_scope(); auto op_name = cnode->fullname_with_scope();
auto op_type = primitiveT_value->GetPrimitiveT()->value.type;
MS_LOG(INFO) << "OpName: " << op_name; MS_LOG(INFO) << "OpName: " << op_name;
if (input_vec.size() <= 3 && op_name != "Conv2D" && op_name != "DepthwiseConv2D") { if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) {
MS_LOG(INFO) << "todo(x): "; for (auto i = 1; i < cnode->inputs().size(); i++) {
// int32_t qnodeOutputZeropoint = outputZeropoint[cnode]; auto input_node = cnode->input(i);
// p->AddAttr(kInputTensorDataType, MakeValue((int)targetType)); if (!input_node->isa<mindspore::CNode>()) {
MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode";
continue;
}
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(input_cnode->input(0));
if (input_cnode_primitiveT_value == nullptr) {
MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
<< " PrimitiveTValue is null";
continue;
}
for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
primitiveT_value->AddInputQuantParam(quant_param);
}
}
} else { } else {
// do input quant // do input quant
double scale = input_scale[cnode]; double scale = input_scale[cnode];
......
...@@ -55,15 +55,18 @@ struct ConfigParam { ...@@ -55,15 +55,18 @@ struct ConfigParam {
class PostTrainingQuantizer : public Quantizer { class PostTrainingQuantizer : public Quantizer {
public: public:
PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8); PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
bool per_channel = false);
STATUS DoQuantize(FuncGraphPtr funcGraph) override; STATUS DoQuantize(FuncGraphPtr funcGraph) override;
size_t bit_num; size_t bit_num;
int quant_max{255}; int quant_max{127};
int quant_min{0}; int quant_min{-128};
private: private:
bool per_channel_;
TypeId target_type_{kNumberTypeInt8}; TypeId target_type_{kNumberTypeInt8};
std::unique_ptr<Calibrator> calibrator_; std::unique_ptr<Calibrator> calibrator_;
......
...@@ -25,10 +25,11 @@ namespace mindspore::lite::quant { ...@@ -25,10 +25,11 @@ namespace mindspore::lite::quant {
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) { ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>(); std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>();
schema::QuantDTypeCastT quant_dtype_cast; schema::QuantDTypeCastT quant_dtype_cast;
quant_dtype_cast.srcT = src_type; // kNumberTypeUInt8; quant_dtype_cast.srcT = src_type; // kNumberTypeInt8;
quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32; quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32;
primitive->value.Set(quant_dtype_cast); primitive->value.Set(quant_dtype_cast);
auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release()); auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release());
primTValue->SetQuantType(schema::QuantType_PostTraining);
for (auto &quant_param : quant_params) { for (auto &quant_param : quant_params) {
primTValue->AddInputQuantParam(quant_param); primTValue->AddInputQuantParam(quant_param);
} }
...@@ -52,7 +53,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { ...@@ -52,7 +53,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
if (first) { if (first) {
if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
auto value_node = auto value_node =
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, primitiveT_value->GetInputQuantParams()); NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams());
std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
auto quant_cast_cnode = graph->NewCNode(op_inputs); auto quant_cast_cnode = graph->NewCNode(op_inputs);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
...@@ -82,11 +83,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { ...@@ -82,11 +83,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
ValueNodePtr value_node = nullptr; ValueNodePtr value_node = nullptr;
if (curnode_quant_type == schema::QuantType_PostTraining && if (curnode_quant_type == schema::QuantType_PostTraining &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) { input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
input_cnode_primitiveT_value->GetInputQuantParams()); primitiveT_value->GetInputQuantParams());
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE && } else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) { input_cnode_quant_type == schema::QuantType_PostTraining) {
value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32, value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
input_cnode_primitiveT_value->GetInputQuantParams()); input_cnode_primitiveT_value->GetInputQuantParams());
} }
if (value_node == nullptr) { if (value_node == nullptr) {
......
...@@ -98,7 +98,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { ...@@ -98,7 +98,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
static const std::vector<schema::PrimitiveType> uint8OpList = { static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape, schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape,
schema::PrimitiveType_Activation}; schema::PrimitiveType_Activation};
return IsContain(uint8OpList, type); return IsContain(uint8OpList, type);
} }
...@@ -242,64 +242,122 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double ...@@ -242,64 +242,122 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
return RET_OK; return RET_OK;
} }
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) { STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum,
bool per_channel) {
if (per_channel) {
// per channel
auto dims = weightPtr->tensor_shape(); auto dims = weightPtr->tensor_shape();
if (dims.size() < 1) { if (dims.size() < 1) {
MS_LOG(ERROR) << "weight dims size error"; MS_LOG(ERROR) << "weight dims size error";
return RET_ERROR; return RET_ERROR;
} }
uint32_t channels = dims[0]; // todo(x)
uint32_t channels = dims[3];
if (channels == 0) { if (channels == 0) {
MS_LOG(ERROR) << "channels error 0"; MS_LOG(ERROR) << "channels error 0";
return RET_ERROR; return RET_ERROR;
} }
size_t shapeSize = weightPtr->tensor_shape_size(); size_t shapeSize = weightPtr->tensor_shape_size();
size_t oneFilterSize = shapeSize / channels; size_t oneFilterSize = shapeSize / channels;
auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr()); auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
if (rawDatas == nullptr) { if (rawDatas == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr"; MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR; return RET_ERROR;
} }
weightPtr->quant_param().clear(); weightPtr->quant_param().clear();
vector<uint8_t> qDatas(shapeSize); vector<int8_t> qDatas(shapeSize);
for (uint32_t i = 0; i < channels; i++) { for (uint32_t i = 0; i < channels; i++) {
float min = 0; float min = 0;
float max = 0; float max = 0;
// find min and max // find min and max
for (uint32_t j = 0; j < oneFilterSize; j++) { for (uint32_t j = 0; j < oneFilterSize; j++) {
min = std::min(min, rawDatas[j + i * oneFilterSize]); min = std::min(min, rawDatas[j + i * oneFilterSize]);
max = std::max(max, rawDatas[j + i * oneFilterSize]); max = std::max(max, rawDatas[j + i * oneFilterSize]);
} }
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
// update data and datatype
for (uint32_t j = 0; j < oneFilterSize; j++) {
float rawData = rawDatas[j + i * oneFilterSize];
auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
qDatas[j + i * oneFilterSize] = qData;
}
weightPtr->set_quant_param(quantParam);
}
auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
if (quantType == QuantType_WeightQuant) {
PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum);
}
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam); weightPtr->set_tensor_type(kNumberTypeInt8);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
if (status != RET_OK) { } else {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status; // per layer
return status; size_t shapeSize = weightPtr->tensor_shape_size();
} auto *rawDatas = static_cast<float *>(weightPtr->tensor_addr());
// update data and datatype if (rawDatas == nullptr) {
for (uint32_t j = 0; j < oneFilterSize; j++) { MS_LOG(ERROR) << "rawDatas is nullptr";
float rawData = rawDatas[j + i * oneFilterSize]; return RET_ERROR;
auto qData = QuantizeData<uint8_t>(rawData, quantParam.get()); }
qDatas[j + i * oneFilterSize] = qData;
}
weightPtr->set_quant_param(quantParam); weightPtr->quant_param().clear();
vector<int8_t> qDatas(shapeSize);
float min = 0;
float max = 0;
for (uint32_t i = 0; i < shapeSize; i++) {
// find max min
min = std::min(min, rawDatas[i]);
max = std::max(max, rawDatas[i]);
} }
auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(uint8_t)); std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
// update data and datatype
for (uint32_t i = 0; i < shapeSize; i++) {
float rawData = rawDatas[i];
auto quant_data = std::round(rawData / quantParam->scale + quantParam->zeroPoint);
if (quant_data > quant_max) {
qDatas[i] = quant_max;
} else if (quant_data < quant_min) {
qDatas[i] = quant_min;
} else {
qDatas[i] = static_cast<int8_t>(quant_data);
}
}
weightPtr->set_quant_param(quantParam);
auto ret = memcpy_s(rawDatas, weightPtr->tensor_size() * sizeof(int8_t),
qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) { if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret; MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR; return RET_ERROR;
} }
if (quantType == QuantType_WeightQuant) { if (quantType == QuantType_WeightQuant) {
PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum); PostBitPack(rawDatas, shapeSize, bitNum);
} }
weightPtr->set_tensor_type(kNumberTypeInt8); weightPtr->set_tensor_type(kNumberTypeInt8);
weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
}
return RET_OK; return RET_OK;
} }
......
...@@ -63,41 +63,30 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double ...@@ -63,41 +63,30 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
bool narrowRange, int quant_max, int quant_min, int num_bits); bool narrowRange, int quant_max, int quant_min, int num_bits);
template <typename T> template <typename T>
T QuantizeData(const float originData, const AnfQuantParam *quantParam) { T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) {
MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam != nullptr);
MS_ASSERT(quantParam->inited); MS_ASSERT(quantParam->inited);
const auto scale = quantParam->scale; const auto scale = quantParam->scale;
const auto zeroPoint = quantParam->zeroPoint; const int zeroPoint = quantParam->zeroPoint;
const auto numBit = quantParam->numBits;
const auto narrowRange = quantParam->narrowRange; const auto narrowRange = quantParam->narrowRange;
const double maxLimit = static_cast<float>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; const int maxLimit = quant_max;
double minLimit; const int minLimit = quant_min;
if (narrowRange) {
minLimit = static_cast<float>(1 - zeroPoint) * scale;
} else {
minLimit = static_cast<float>(0 - zeroPoint) * scale;
}
return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
double tmp = 0.0f; int quant_data = std::round(originData / scale + zeroPoint);
if (originData > maxLimit) { if (quant_data > maxLimit) {
tmp = maxLimit; quant_data = maxLimit;
} else if (originData < minLimit) { } else if (quant_data < minLimit) {
tmp = minLimit; quant_data = minLimit;
} else {
tmp = originData;
}
auto quantData = static_cast<T>(std::round(tmp / scale + zeroPoint));
if (quantData == 0 && narrowRange) {
quantData++;
} }
return quantData; return static_cast<T>(quant_data);
}(); }();
} }
void CalFakeNode(const AnfNodePtr &inTensor); void CalFakeNode(const AnfNodePtr &inTensor);
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min,
size_t bitNum = UINT8_QUANTIZATION); size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false);
STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册