提交 1fc91411 编写于 作者: X xutianchun

Post Training Quantization

上级 0df5a561
......@@ -177,31 +177,44 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
if (node->quantType == schema::QuantType_PostTraining) {
MS_LOG(INFO) << "node: " << node->name << " add QuantParam";
// activation
auto activate_index = node->inputIndex[0];
auto tensor_input = metaGraphT->allTensors[activate_index].get();
auto input_quant_params = primitiveT_value->GetInputQuantParams();
if (input_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << node->name
<< " input quant params is empty";
} else {
auto node_type = primitiveT_value->GetPrimitiveT()->value.type;
for (int i = 0; i < input_quant_params.size(); i++) {
if (i >= node->inputIndex.size()) {
MS_LOG(ERROR) << "node: " << node->name << " input has " << input_quant_params.size()
<< " quant_params; but only " << node->inputIndex.size() << " input";
break;
}
auto activate_index = node->inputIndex[i];
auto tensor_input = metaGraphT->allTensors[activate_index].get();
std::unique_ptr<schema::QuantParamT> input_quant_param =
std::make_unique<schema::QuantParamT>(input_quant_params[0]);
std::make_unique<schema::QuantParamT>(input_quant_params[i]);
MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
<< " zp: " << input_quant_param->zeroPoint;
tensor_input->quantParams.emplace_back(std::move(input_quant_param));
if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
tensor_input->dataType = kNumberTypeInt8;
}
}
tensor_input->dataType = kNumberTypeInt8;
// output
auto output_index = node->outputIndex[0];
auto tensor_output = metaGraphT->allTensors[output_index].get();
auto output_quant_params = primitiveT_value->GetOutputQuantParams();
if (output_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << node->name
<< " output quant params is empty";
MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
} else {
std::unique_ptr<schema::QuantParamT> output_quant_param =
std::make_unique<schema::QuantParamT>(output_quant_params[0]);
std::make_unique<schema::QuantParamT>(output_quant_params[0]);
MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
<< " zp: " << output_quant_param->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param));
}
tensor_output->dataType = kNumberTypeInt8;
if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) {
tensor_output->dataType = kNumberTypeInt8;
}
// // TensorType
// valuePtr = primitive->GetAttr(kInputTensorDataType);
// if (valuePtr != nullptr) {
......
......@@ -64,6 +64,16 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
// no copy data, do copy when call LiteKernel::Init
dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data()));
}
auto quant_params = srcTensor->quantParams();
if (quant_params != nullptr) {
for (int j = 0; j < quant_params->size(); j++) {
tensor::QuantArg quant_arg{};
quant_arg.scale = quant_params->Get(j)->scale();
quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint();
dstTensor->AddQuantParam(quant_arg);
}
}
this->tensors.emplace_back(dstTensor);
}
return RET_OK;
......
......@@ -30,6 +30,7 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
auto param = primitive->value_as_QuantDTypeCast();
MS_ASSERT(input->data_type() == param->srcT);
output->set_data_type(static_cast<TypeId>(param->dstT()));
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace mindspore::lite
......@@ -58,7 +58,7 @@ int QuantDTypeCastCPUKernel::Init() {
}
inverse_ = true;
} else {
MS_LOG(ERROR) << "param data type not supported.";
MS_LOG(ERROR) << "param data type not supported:" << " src: " << param->srcT << " dst: " << param->dstT;
return RET_ERROR;
}
......@@ -143,7 +143,6 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::t
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
} // namespace mindspore::kernel
......@@ -23,7 +23,7 @@ int DequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_
}
for (int i = 0; i < size; ++i) {
real_values[i] = (quant_values[i] + zp) * scale;
real_values[i] = (quant_values[i] - zp) * scale;
}
return NNACL_OK;
}
......@@ -34,7 +34,14 @@ int QuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_
}
for (int i = 0; i < size; ++i) {
quant_values[i] = (int8_t)round(real_values[i] / scale + zp);
float temp = round(real_values[i] / scale + zp);
if (temp > 127) {
quant_values[i] = 127;
} else if (temp < -128) {
quant_values[i] = -128;
} else {
quant_values[i] = (int8_t)temp;
}
}
return NNACL_OK;
}
......@@ -166,6 +166,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return -1;
}
}
MS_LOG(DEBUG) << "weight_tensor_format: " << weightTensor->format;
return 0;
} else if (fmkType == converter::FmkType_ONNX) {
switch (node->quantType) {
......@@ -217,7 +218,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
auto opType = node->primitive->value.type;
if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D &&
opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) {
return 0;
return RET_OK;
}
MS_ASSERT(node->inputIndex.size() >= 2);
......@@ -225,7 +226,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
MS_ASSERT(subGraph->allTensors.size() > weightIndex);
auto &weightTensor = subGraph->allTensors[weightIndex];
MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT
STATUS status;
STATUS status = RET_OK;
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK
if (weightTensor->format == schema::Format_KCHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
......@@ -238,11 +239,12 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
}
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
}
return RET_OK;
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else {
......@@ -273,8 +275,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK);
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK);
}
......
......@@ -54,7 +54,7 @@ struct DivergInfo {
size_t bit_num;
int quant_max = 255;
int quant_min = 0;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max = 255, int quant_min = 0) {
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) {
this->cnode = cnode;
this->bin_num = bins;
this->bit_num = bits;
......@@ -81,6 +81,9 @@ struct DivergInfo {
STATUS UpdateHistogram(const std::vector<float> &data, const std::vector<int> &shape) {
for (auto value : data) {
if (value == 0) {
continue;
}
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
this->histogram[bin_index]++;
}
......@@ -470,8 +473,10 @@ STATUS Calibrator::ReadConfig() {
Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin)
: config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {}
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type)
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type,
bool per_channel)
: Quantizer(graph) {
this->per_channel_ = per_channel;
this->bit_num = bit_num;
this->target_type_ = target_type;
if (target_type == kNumberTypeInt8) {
......@@ -533,7 +538,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) {
}
auto parameter = std::dynamic_pointer_cast<Parameter>(node);
ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num);
auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num, per_channel_);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
......@@ -670,18 +675,32 @@ STATUS PostTrainingQuantizer::QuantNode() {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
continue;
}
if (input_scale.find(cnode) == input_scale.end()) {
primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE);
continue;
}
auto input_vec = cnode->inputs();
auto op_name = cnode->fullname_with_scope();
auto op_type = primitiveT_value->GetPrimitiveT()->value.type;
MS_LOG(INFO) << "OpName: " << op_name;
if (input_vec.size() <= 3 && op_name != "Conv2D" && op_name != "DepthwiseConv2D") {
MS_LOG(INFO) << "todo(x): ";
// int32_t qnodeOutputZeropoint = outputZeropoint[cnode];
// p->AddAttr(kInputTensorDataType, MakeValue((int)targetType));
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) {
for (auto i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
if (!input_node->isa<mindspore::CNode>()) {
MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode";
continue;
}
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(input_cnode->input(0));
if (input_cnode_primitiveT_value == nullptr) {
MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
<< " PrimitiveTValue is null";
continue;
}
for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
primitiveT_value->AddInputQuantParam(quant_param);
}
}
} else {
// do input quant
double scale = input_scale[cnode];
......
......@@ -55,15 +55,18 @@ struct ConfigParam {
class PostTrainingQuantizer : public Quantizer {
public:
PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8);
PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
bool per_channel = false);
STATUS DoQuantize(FuncGraphPtr funcGraph) override;
size_t bit_num;
int quant_max{255};
int quant_min{0};
int quant_max{127};
int quant_min{-128};
private:
bool per_channel_;
TypeId target_type_{kNumberTypeInt8};
std::unique_ptr<Calibrator> calibrator_;
......
......@@ -25,10 +25,11 @@ namespace mindspore::lite::quant {
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>();
schema::QuantDTypeCastT quant_dtype_cast;
quant_dtype_cast.srcT = src_type; // kNumberTypeUInt8;
quant_dtype_cast.srcT = src_type; // kNumberTypeInt8;
quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32;
primitive->value.Set(quant_dtype_cast);
auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release());
primTValue->SetQuantType(schema::QuantType_PostTraining);
for (auto &quant_param : quant_params) {
primTValue->AddInputQuantParam(quant_param);
}
......@@ -52,7 +53,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
if (first) {
if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
auto value_node =
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, primitiveT_value->GetInputQuantParams());
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams());
std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
auto quant_cast_cnode = graph->NewCNode(op_inputs);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
......@@ -82,11 +83,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
ValueNodePtr value_node = nullptr;
if (curnode_quant_type == schema::QuantType_PostTraining &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8,
input_cnode_primitiveT_value->GetInputQuantParams());
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
primitiveT_value->GetInputQuantParams());
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) {
value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32,
value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
input_cnode_primitiveT_value->GetInputQuantParams());
}
if (value_node == nullptr) {
......
......@@ -98,7 +98,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape,
schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape,
schema::PrimitiveType_Activation};
return IsContain(uint8OpList, type);
}
......@@ -242,64 +242,122 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
return RET_OK;
}
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) {
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum,
bool per_channel) {
if (per_channel) {
// per channel
auto dims = weightPtr->tensor_shape();
if (dims.size() < 1) {
MS_LOG(ERROR) << "weight dims size error";
return RET_ERROR;
MS_LOG(ERROR) << "weight dims size error";
return RET_ERROR;
}
uint32_t channels = dims[0];
// todo(x)
uint32_t channels = dims[3];
if (channels == 0) {
MS_LOG(ERROR) << "channels error 0";
return RET_ERROR;
MS_LOG(ERROR) << "channels error 0";
return RET_ERROR;
}
size_t shapeSize = weightPtr->tensor_shape_size();
size_t oneFilterSize = shapeSize / channels;
auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
if (rawDatas == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR;
MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR;
}
weightPtr->quant_param().clear();
vector<uint8_t> qDatas(shapeSize);
vector<int8_t> qDatas(shapeSize);
for (uint32_t i = 0; i < channels; i++) {
float min = 0;
float max = 0;
// find min and max
for (uint32_t j = 0; j < oneFilterSize; j++) {
min = std::min(min, rawDatas[j + i * oneFilterSize]);
max = std::max(max, rawDatas[j + i * oneFilterSize]);
}
float min = 0;
float max = 0;
// find min and max
for (uint32_t j = 0; j < oneFilterSize; j++) {
min = std::min(min, rawDatas[j + i * oneFilterSize]);
max = std::max(max, rawDatas[j + i * oneFilterSize]);
}
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
// update data and datatype
for (uint32_t j = 0; j < oneFilterSize; j++) {
float rawData = rawDatas[j + i * oneFilterSize];
auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
qDatas[j + i * oneFilterSize] = qData;
}
weightPtr->set_quant_param(quantParam);
}
auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
if (quantType == QuantType_WeightQuant) {
PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum);
}
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
// update data and datatype
for (uint32_t j = 0; j < oneFilterSize; j++) {
float rawData = rawDatas[j + i * oneFilterSize];
auto qData = QuantizeData<uint8_t>(rawData, quantParam.get());
qDatas[j + i * oneFilterSize] = qData;
}
weightPtr->set_tensor_type(kNumberTypeInt8);
weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
} else {
// per layer
size_t shapeSize = weightPtr->tensor_shape_size();
auto *rawDatas = static_cast<float *>(weightPtr->tensor_addr());
if (rawDatas == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR;
}
weightPtr->set_quant_param(quantParam);
weightPtr->quant_param().clear();
vector<int8_t> qDatas(shapeSize);
float min = 0;
float max = 0;
for (uint32_t i = 0; i < shapeSize; i++) {
// find max min
min = std::min(min, rawDatas[i]);
max = std::max(max, rawDatas[i]);
}
auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(uint8_t));
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
// update data and datatype
for (uint32_t i = 0; i < shapeSize; i++) {
float rawData = rawDatas[i];
auto quant_data = std::round(rawData / quantParam->scale + quantParam->zeroPoint);
if (quant_data > quant_max) {
qDatas[i] = quant_max;
} else if (quant_data < quant_min) {
qDatas[i] = quant_min;
} else {
qDatas[i] = static_cast<int8_t>(quant_data);
}
}
weightPtr->set_quant_param(quantParam);
auto ret = memcpy_s(rawDatas, weightPtr->tensor_size() * sizeof(int8_t),
qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
if (quantType == QuantType_WeightQuant) {
PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum);
PostBitPack(rawDatas, shapeSize, bitNum);
}
weightPtr->set_tensor_type(kNumberTypeInt8);
weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
}
return RET_OK;
}
......
......@@ -63,41 +63,30 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
bool narrowRange, int quant_max, int quant_min, int num_bits);
template <typename T>
T QuantizeData(const float originData, const AnfQuantParam *quantParam) {
T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) {
MS_ASSERT(quantParam != nullptr);
MS_ASSERT(quantParam->inited);
const auto scale = quantParam->scale;
const auto zeroPoint = quantParam->zeroPoint;
const auto numBit = quantParam->numBits;
const int zeroPoint = quantParam->zeroPoint;
const auto narrowRange = quantParam->narrowRange;
const double maxLimit = static_cast<float>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale;
double minLimit;
if (narrowRange) {
minLimit = static_cast<float>(1 - zeroPoint) * scale;
} else {
minLimit = static_cast<float>(0 - zeroPoint) * scale;
}
const int maxLimit = quant_max;
const int minLimit = quant_min;
return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
double tmp = 0.0f;
if (originData > maxLimit) {
tmp = maxLimit;
} else if (originData < minLimit) {
tmp = minLimit;
} else {
tmp = originData;
}
auto quantData = static_cast<T>(std::round(tmp / scale + zeroPoint));
if (quantData == 0 && narrowRange) {
quantData++;
int quant_data = std::round(originData / scale + zeroPoint);
if (quant_data > maxLimit) {
quant_data = maxLimit;
} else if (quant_data < minLimit) {
quant_data = minLimit;
}
return quantData;
return static_cast<T>(quant_data);
}();
}
void CalFakeNode(const AnfNodePtr &inTensor);
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min,
size_t bitNum = UINT8_QUANTIZATION);
size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false);
STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册