提交 6e759cd4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3815 compute threshold only once in post training quantization

Merge pull request !3815 from xutianchun/quant_0731
......@@ -98,7 +98,6 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
}
node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
primitiveT_value->SetPrimitiveT(nullptr);
std::vector<schema::TensorT *> outputs;
SetOpInputNode(cnode, metaGraphT.get(), node.get());
SetOpOutputNode(outputs, metaGraphT.get(), node.get());
......@@ -113,24 +112,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
auto input_quant_params = primitiveT_value->GetInputQuantParams();
if (input_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty";
continue;
} else {
std::unique_ptr<schema::QuantParamT> input_quant_param =
std::make_unique<schema::QuantParamT>(input_quant_params[0]);
tensor_input->quantParams.emplace_back(std::move(input_quant_param));
}
std::unique_ptr<schema::QuantParamT> input_quant_param =
std::make_unique<schema::QuantParamT>(input_quant_params[0]);
tensor_input->quantParams.emplace_back(std::move(input_quant_param));
// output
auto output_index = node->outputIndex[0];
auto tensor_output = metaGraphT->allTensors[output_index].get();
auto output_quant_params = primitiveT_value->GetOutputQuantParams();
if (output_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
continue;
} else {
std::unique_ptr<schema::QuantParamT> output_quant_param =
std::make_unique<schema::QuantParamT>(output_quant_params[0]);
tensor_output->quantParams.emplace_back(std::move(output_quant_param));
}
std::unique_ptr<schema::QuantParamT> output_quant_param =
std::make_unique<schema::QuantParamT>(output_quant_params[0]);
tensor_output->quantParams.emplace_back(std::move(output_quant_param));
// // TensorType
// valuePtr = primitive->GetAttr(kInputTensorDataType);
// if (valuePtr != nullptr) {
......
......@@ -26,8 +26,8 @@ namespace mindspore::lite {
class PrimitiveTValue : public Value {
public:
explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {}
~PrimitiveTValue() override { delete this->primitive; }
// not responsible to free primitive, the one created the dynamic memory is responsible to free it.
~PrimitiveTValue() override = default;
MS_DECLARE_PARENT(PrimitiveTValue, Value)
......
......@@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status;
return status;
}
if (this->quantType == QuantType_AwareTrainning) {
if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) {
status = QuantDataFormatTrans(graphNode);
if (status != 0) {
MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
......@@ -147,7 +147,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
} else if (fmkType == converter::FmkType_TFLITE) {
switch (node->quantType) {
case QuantType_QUANT_NONE:
case QuantType_AwareTrainning: {
case QuantType_AwareTrainning:
case QuantType_PostTraining: {
if (opType == schema::PrimitiveType_Conv2D) {
weightTensor->format = schema::Format_KHWC;
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
......
......@@ -292,13 +292,32 @@ STATUS Calibrator::RecordMaxValue(std::string opName, vector<float> data,
}
STATUS Calibrator::ComputeThreshold() {
for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) {
DivergInfo *info = iter->second.get();
info->ComputeThreshold();
}
for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) {
// node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as
for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
DivergInfo *info = iter->second.get();
info->ComputeThreshold();
auto cnode = info->cnode;
bool already_computed = false;
auto input = cnode->input(1);
if (input->isa<mindspore::CNode>()) {
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input);
for (const auto &output_diverg_info : output_diverg_info_) {
auto output_diverg_cnode = output_diverg_info.second->cnode;
if (output_diverg_cnode == input_cnode) {
*info = *(output_diverg_info.second);
info->cnode = cnode;
already_computed = true;
break;
}
}
}
if (!already_computed) {
info->ComputeThreshold();
}
}
return RET_OK;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册