Unverified Commit 4e4f4586 authored by: Y yeliang2258 committed by: GitHub

New format quant model support for MKLDNN (#45416)

* support onnx format quantized model

* update code

* add test

* add test

* fix

* fix test

* fix cmake

* update code

* change scale file path to calibration file path

* update code

* update code

* fix build bug

* fix build bugs

* fix

* fix
Parent fd56f08e
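This commit teaches the MKLDNN int8 path to consume models quantized in the new ONNX format (quantize_linear / dequantize_linear ops) and adds a calibration file path to AnalysisConfig so activation scales can be loaded at inference time. The sketch below shows the intended inference-side usage, mirroring the test added at the end of this diff; the model directory and file names are placeholders.

```python
# Minimal sketch of the inference-side usage (paths are placeholders).
import paddle.inference as paddle_infer

model_dir = "./mobilenetv1_onnx_format_int8"   # ONNX-format quantized model
calib_path = model_dir + "/calibration_table.txt"

config = paddle_infer.Config(model_dir)
config.disable_gpu()
config.enable_mkldnn()
config.set_calibration_file_path(calib_path)   # new API added in this commit
config.enable_mkldnn_int8()
predictor = paddle_infer.create_predictor(config)
```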
......@@ -96,7 +96,49 @@ void QuantDequantMkldnnPass::CollectInfoFromFake(
}
}
void QuantDequantMkldnnPass::CollectInputScalesFromFake(
void QuantDequantMkldnnPass::CollectWeightScalesInfoFromONNXFormatDequantize(
ir::Graph* graph,
Scope* scope,
std::unordered_map<std::string, std::vector<float>>* weight_thresholds,
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
bool* onnx_format_quantize_model) const {
VLOG(3) << "gather weight_thresholds from onnx format dequantized ops";
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
if (op_node->Name() == "dequantize_linear") {
auto* op_desc = op_node->Op();
auto x_var_name = op_desc->Input("X")[0];
auto* weight_var = scope->FindVar(x_var_name);
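// If X is not found in the scope it is an activation: propagate its scale
// to the dequantize output. Otherwise X is an int8 weight: flag the model
// as ONNX-format quantized and record its scales as weight thresholds.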
if (!weight_var) {
auto out_var_name = op_desc->Output("Y")[0];
if (var_quant_scales->count(x_var_name) &&
!var_quant_scales->count(out_var_name)) {
std::vector<float> scale_v = var_quant_scales->at(x_var_name);
var_quant_scales->insert(std::make_pair(out_var_name, scale_v));
}
} else {
*onnx_format_quantize_model = true;
auto scale_name = op_desc->Input("Scale")[0];
auto* var = scope->FindVar(scale_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The Scale variable [%s] of dequantize op is not found.",
scale_name));
auto* scale_tensor = var->GetMutable<LoDTensor>();
auto* scale_data = scale_tensor->data<float>();
std::vector<float> thresholds(scale_data,
scale_data + scale_tensor->numel());
weight_thresholds->insert(std::make_pair(x_var_name, thresholds));
}
}
}
}
void QuantDequantMkldnnPass::CollectInputScalesFromQuantize(
ir::Graph* graph,
Scope* scope,
const std::unordered_set<std::string>& fake_quantize_types,
......@@ -108,6 +150,7 @@ void QuantDequantMkldnnPass::CollectInputScalesFromFake(
if (!op_node->IsOp()) continue;
if (op_node->Name() == "fake_quantize_dequantize_moving_average_abs_max" ||
op_node->Name() == "quantize_linear" ||
fake_quantize_types.count(op_node->Name())) {
auto* op_desc = op_node->Op();
const int bit_length =
......@@ -119,10 +162,17 @@ void QuantDequantMkldnnPass::CollectInputScalesFromFake(
"bits: %d, only 8 is supported now.",
bit_length));
std::string scale_name = "InScale";
std::string out_name = "Out";
if (op_node->Name() == "quantize_linear") {
scale_name = "Scale";
out_name = "Y";
}
auto x_var_name = op_desc->Input("X")[0];
auto scale_name = op_desc->Input("InScale")[0];
auto out_var_name = op_desc->Output("Out")[0];
auto* var = scope->FindVar(scale_name);
auto scale_var_name = op_desc->Input(scale_name)[0];
auto out_var_name = op_desc->Output(out_name)[0];
auto* var = scope->FindVar(scale_var_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
......@@ -275,12 +325,66 @@ void QuantDequantMkldnnPass::CollectFakeDequantizeOps(
nodes2rm->insert(fake_dequant_out);
}
void QuantDequantMkldnnPass::CollectQuantizeDequantizeOpsFromONNXFormat(
ir::Graph* graph,
Node* op_node,
std::unordered_set<const Node*>* nodes2rm) const {
auto* op_desc = op_node->Op();
auto x_var_name = op_desc->Input("X")[0];
auto in_scale_name = op_desc->Input("Scale")[0];
auto in_zero_name = op_desc->Input("ZeroPoint")[0];
auto out_var_name = op_desc->Output("Y")[0];
Node* fake_quant_in = nullptr;
Node* fake_quant_in_scale = nullptr;
for (auto* node_input : op_node->inputs) {
if (node_input->Name() == x_var_name) {
fake_quant_in = node_input;
} else if (node_input->Name() == in_scale_name) {
fake_quant_in_scale = node_input;
}
}
Node* fake_quant_out = nullptr;
for (auto* node_output : op_node->outputs) {
if (node_output->Name() == out_var_name) {
fake_quant_out = node_output;
}
}
PADDLE_ENFORCE_NOT_NULL(
fake_quant_in,
platform::errors::NotFound(
"The input var [%s] of quantize op is not found.", x_var_name));
PADDLE_ENFORCE_NOT_NULL(
fake_quant_in_scale,
platform::errors::NotFound(
"The scale var [%s] of quantize op is not found.", in_scale_name));
PADDLE_ENFORCE_NOT_NULL(
fake_quant_out,
platform::errors::NotFound(
"The output var [%s] of quantize op is not found.", out_var_name));
std::string input_act_name = fake_quant_in->Var()->Name();
std::string output_act_name = fake_quant_out->Var()->Name();
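// Bypass the quantize/dequantize op: rewire every consumer of its output
// activation to read the original input activation, then mark the op, its
// scale input and its output for removal.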
for (auto* next_node : fake_quant_out->outputs) {
if (!next_node->IsOp()) continue;
next_node->Op()->RenameInput(output_act_name, input_act_name);
IR_NODE_LINK_TO(fake_quant_in, next_node);
}
nodes2rm->insert(op_node);
nodes2rm->insert(fake_quant_in_scale);
nodes2rm->insert(fake_quant_out);
}
void QuantDequantMkldnnPass::RemoveFakeOps(
ir::Graph* graph,
const std::unordered_set<std::string>& fake_quantize_types,
const std::unordered_set<std::string>& fake_dequantize_types,
const std::unordered_set<std::string>& fake_quantize_dequantize_types)
const {
const std::unordered_set<std::string>& fake_quantize_dequantize_types,
const std::unordered_set<std::string>&
onnx_format_quantize_dequantize_types) const {
VLOG(3) << "remove fake quantize and dequantize ops";
std::unordered_set<const Node*> nodes2rm = {};
......@@ -294,6 +398,8 @@ void QuantDequantMkldnnPass::RemoveFakeOps(
CollectFakeDequantizeOps(graph, op_node, &nodes2rm);
} else if (fake_quantize_dequantize_types.count(op_node->Name())) {
CollectFakeDequantizeOps(graph, op_node, &nodes2rm);
} else if (onnx_format_quantize_dequantize_types.count(op_node->Name())) {
CollectQuantizeDequantizeOpsFromONNXFormat(graph, op_node, &nodes2rm);
}
}
......@@ -357,64 +463,54 @@ bool QuantDequantMkldnnPass::IsInt8Weight(
return is_int8;
}
void QuantDequantMkldnnPass::DequantizeOpWeights(
Node* op_node,
Scope* scope,
const std::string& weight_name,
const std::string& output_name,
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const {
auto* op_desc = op_node->Op();
std::string weight_var_name = op_desc->Input(weight_name)[0];
std::string output_var_name = op_desc->Output(output_name)[0];
std::vector<float> scales;
auto iter = weight_thresholds.find(output_var_name);
if (iter != weight_thresholds.end()) {
scales = iter->second;
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Could not find threshold information for [%s] var, please check if "
"the model is correct.",
output_var_name));
}
auto* var = scope->FindVar(weight_var_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The input persistable [%s] var of [%s] op is not found.",
weight_var_name,
op_desc->Type()));
auto* weight_tensor = var->GetMutable<LoDTensor>();
void QuantDequantMkldnnPass::ConvertFromINT8ToFP32(
const std::vector<float>& scales,
Tensor* weight_tensor,
int8_t* int8_weight_data,
float* fp32_weight_data,
const std::string& weight_var_name) const {
const auto weight_dims = weight_tensor->dims();
std::vector<float> weight_data;
weight_data.resize(weight_tensor->numel());
const int size = scales.size();
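// Rescale weights to fp32 as value / 127 * scale (value comes from either
// int8 or fp32 storage); scales may be per-tensor or per-channel.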
if (size == 1 || size == weight_dims[0]) {
auto* weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < weight_tensor->numel(); i++) {
weight_data[i] /= 127;
if (int8_weight_data) {
weight_data[i] = static_cast<float>(int8_weight_data[i]) / 127.0;
} else {
weight_data[i] = fp32_weight_data[i] / 127.0;
}
}
weight_tensor->clear(); // clear int weight
weight_tensor->Resize(phi::make_ddim(phi::vectorize(weight_dims)));
auto* new_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_weight_data,
weight_data.data(),
weight_tensor->numel() * sizeof(float));
TransposeWeight(weight_tensor);
if (size == 1) {
for (int i = 0; i < weight_tensor->numel(); i++) {
weight_data[i] *= scales[0];
new_weight_data[i] *= scales[0];
}
} else {
for (int i = 0; i < weight_tensor->numel(); i++) {
weight_data[i] *= scales[i % size];
new_weight_data[i] *= scales[i % size];
}
}
TransposeWeight(weight_tensor);
} else if (weight_dims.size() > 1 && size == weight_dims[1]) {
auto* weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < weight_tensor->numel(); i++) {
weight_data[i] /= 127;
if (int8_weight_data) {
weight_data[i] = static_cast<float>(int8_weight_data[i]) / 127.0;
} else {
weight_data[i] = fp32_weight_data[i] / 127.0;
}
}
int step_n = 1;
......@@ -433,6 +529,13 @@ void QuantDequantMkldnnPass::DequantizeOpWeights(
}
}
}
weight_tensor->clear(); // clear int weight
weight_tensor->Resize(phi::make_ddim(phi::vectorize(weight_dims)));
auto* new_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_weight_data,
weight_data.data(),
weight_tensor->numel() * sizeof(float));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The size of weight scales vector (%d) does not "
......@@ -441,15 +544,89 @@ void QuantDequantMkldnnPass::DequantizeOpWeights(
weight_tensor->dims().size(),
weight_var_name));
}
weight_tensor->Resize(weight_dims);
}
void QuantDequantMkldnnPass::DequantizeOpWeights(
Node* op_node,
Scope* scope,
const std::string& weight_name,
const std::string& output_name,
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const {
auto* op_desc = op_node->Op();
std::string weight_var_name = op_desc->Input(weight_name)[0];
std::string output_var_name = op_desc->Output(output_name)[0];
std::vector<float> scales;
auto iter = weight_thresholds.find(output_var_name);
if (iter != weight_thresholds.end()) {
scales = iter->second;
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Could not find threshold information for [%s] var, please check if "
"the model is correct.",
output_var_name));
}
auto* var = scope->FindVar(weight_var_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The input persistable [%s] var of [%s] op is not found.",
weight_var_name,
op_desc->Type()));
auto* weight_tensor = var->GetMutable<LoDTensor>();
float* fp32_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
ConvertFromINT8ToFP32(
scales, weight_tensor, nullptr, fp32_weight_data, weight_var_name);
}
void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
Node* op_node,
Scope* scope,
const std::string& weight_name,
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const {
auto* op_desc = op_node->Op();
std::string weight_var_name = op_desc->Input(weight_name)[0];
std::vector<float> scales;
auto iter = weight_thresholds.find(weight_var_name);
if (iter != weight_thresholds.end()) {
scales = iter->second;
} else {
if (!IsInt8Weight(op_node, scope, weight_name)) {
return;
}
PADDLE_THROW(paddle::platform::errors::Fatal(
"Could not find threshold information for [%s] var, please check if "
"the model is correct.",
weight_var_name));
}
auto* var = scope->FindVar(weight_var_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The input persistable [%s] var of [%s] op is not found.",
weight_var_name,
op_desc->Type()));
auto* weight_tensor = var->GetMutable<LoDTensor>();
int8_t* int8_weight_data =
weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
ConvertFromINT8ToFP32(
scales, weight_tensor, int8_weight_data, nullptr, weight_var_name);
}
void QuantDequantMkldnnPass::DequantizeWeights(
ir::Graph* graph,
Scope* scope,
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const {
weight_thresholds,
const bool& onnx_format_quantize_model) const {
VLOG(3) << "dequantize weight for ops which has weight";
if (weight_thresholds.empty()) {
......@@ -462,13 +639,19 @@ void QuantDequantMkldnnPass::DequantizeWeights(
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") {
if (IsInt8Weight(op_node, scope, "Filter")) {
if (onnx_format_quantize_model) {
DequantizeOpWeightsFromONNXFormat(
op_node, scope, "Filter", weight_thresholds);
} else if (IsInt8Weight(op_node, scope, "Filter")) {
DequantizeOpWeights(
op_node, scope, "Filter", "Output", weight_thresholds);
}
} else if (op_node->Name() == "mul" || op_node->Name() == "matmul" ||
op_node->Name() == "matmul_v2") {
if (IsInt8Weight(op_node, scope, "Y")) {
if (onnx_format_quantize_model) {
DequantizeOpWeightsFromONNXFormat(
op_node, scope, "Y", weight_thresholds);
} else if (IsInt8Weight(op_node, scope, "Y")) {
DequantizeOpWeights(op_node, scope, "Y", "Out", weight_thresholds);
}
}
......@@ -526,20 +709,34 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
"fake_quantize_dequantize_moving_average_abs_max",
"fake_channel_wise_quantize_dequantize_abs_max"};
const std::unordered_set<std::string> onnx_format_quantize_dequantize_types =
{"quantize_linear", "dequantize_linear"};
std::unordered_map<std::string, std::vector<float>> weight_thresholds{};
std::unordered_map<std::string, std::vector<float>> var_quant_scales{};
bool onnx_format_quantize_model = false;
auto* scope = param_scope();
GetInfoFromTheFirstOp(
graph, "has_quant_info", "var_quant_scales", &var_quant_scales);
VLOG(1) << "The nums of scale info from slim txt is: "
<< var_quant_scales.size();
MarkSkipQuantizedOps(graph, skip_ops);
CollectInfoFromFake(graph, scope, fake_dequantize_types, &weight_thresholds);
CollectInputScalesFromFake(
CollectWeightScalesInfoFromONNXFormatDequantize(graph,
scope,
&weight_thresholds,
&var_quant_scales,
&onnx_format_quantize_model);
CollectInputScalesFromQuantize(
graph, scope, fake_quantize_types, &var_quant_scales);
CollectOutputScalesFromAttr(graph, &var_quant_scales);
RemoveFakeOps(graph,
fake_quantize_types,
fake_dequantize_types,
fake_quantize_dequantize_types);
DequantizeWeights(graph, scope, weight_thresholds);
fake_quantize_dequantize_types,
onnx_format_quantize_dequantize_types);
DequantizeWeights(
graph, scope, weight_thresholds, onnx_format_quantize_model);
UpdateActivations(graph);
RemoveCtrlVars(graph);
......
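For reference, the new ConvertFromINT8ToFP32 helper above restores fp32 weights as value / 127 * scale, with either a single per-tensor scale or a per-channel vector. Below is a rough NumPy sketch of the same arithmetic, illustrative only; it ignores the transpose bookkeeping the pass performs for channel-wise scales.

```python
import numpy as np

def dequantize_weight(int8_weight, scales):
    # int8_weight: np.int8 array; scales: per-tensor (len 1) or per-channel list
    w = int8_weight.astype(np.float32) / 127.0
    s = np.asarray(scales, dtype=np.float32)
    flat = w.reshape(-1)
    # mirrors new_weight_data[i] *= scales[i % size] in the pass
    flat *= s[np.arange(flat.size) % s.size]
    return flat.reshape(w.shape)
```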
......@@ -43,13 +43,34 @@ class QuantDequantMkldnnPass : public FusePassBase {
std::unordered_map<std::string, std::vector<float>>* weight_thresholds)
const;
void CollectInputScalesFromFake(
///
/// \brief Collect weight scale info from ONNX-format dequantize_linear ops.
/// weight_thresholds: scale info for weights
/// var_quant_scales: scale info for activations
/// onnx_format_quantize_model: records whether the quantized model is an
/// ONNX-format quantized model
///
void CollectWeightScalesInfoFromONNXFormatDequantize(
ir::Graph* graph,
Scope* scope,
std::unordered_map<std::string, std::vector<float>>* weight_thresholds,
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
bool* onnx_format_quantize_model) const;
void CollectInputScalesFromQuantize(
ir::Graph* graph,
Scope* scope,
const std::unordered_set<std::string>& fake_quantize_types,
std::unordered_map<std::string, std::vector<float>>* var_quant_scales)
const;
void ConvertFromINT8ToFP32(const std::vector<float>& scales,
Tensor* weight_tensor,
int8_t* int8_weight_data,
float* fp32_weight_data,
const std::string& weight_var_name) const;
void CollectOutputScalesFromAttr(
ir::Graph* graph,
std::unordered_map<std::string, std::vector<float>>* var_quant_scales)
......@@ -64,12 +85,22 @@ class QuantDequantMkldnnPass : public FusePassBase {
Node* op_node,
std::unordered_set<const Node*>* nodes2rm) const;
///
/// \brief Collect all ONNX-format quantize/dequantize related ops to remove.
/// nodes2rm: records all quantize-related ops to remove
///
void CollectQuantizeDequantizeOpsFromONNXFormat(
ir::Graph* graph,
Node* op_node,
std::unordered_set<const Node*>* nodes2rm) const;
void RemoveFakeOps(
ir::Graph* graph,
const std::unordered_set<std::string>& fake_quantize_types,
const std::unordered_set<std::string>& fake_dequantize_types,
const std::unordered_set<std::string>& fake_quantize_dequantize_types)
const;
const std::unordered_set<std::string>& fake_quantize_dequantize_types,
const std::unordered_set<std::string>&
onnx_format_quantize_dequantize_types) const;
bool IsInt8Weight(Node* op_node,
Scope* scope,
......@@ -85,11 +116,23 @@ class QuantDequantMkldnnPass : public FusePassBase {
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const;
///
/// \brief Dequantize the weights of conv or matmul ops quantized in ONNX format.
/// weight_thresholds: recorded scale info for weights
///
void DequantizeOpWeightsFromONNXFormat(
Node* op_node,
Scope* scope,
const std::string& weight_name,
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const;
void DequantizeWeights(
ir::Graph* graph,
Scope* scope,
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const;
weight_thresholds,
const bool& onnx_format_quantize_model) const;
void UpdateActivations(ir::Graph* graph) const;
......
......@@ -177,6 +177,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);
#ifdef PADDLE_WITH_MKLDNN
// Calibration file path of quantize model
DECL_ARGUMENT_FIELD(calibration_file_path, CalibrationFilePath, std::string);
// A set of op types to enable their quantized kernels
DECL_ARGUMENT_FIELD(quantize_enabled_op_types,
QuantizeEnabledOpTypes,
......
......@@ -20,6 +20,10 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#endif
namespace paddle {
namespace inference {
namespace analysis {
......@@ -32,6 +36,19 @@ void IrAnalysisPass::RunImpl(Argument* argument) {
auto* the_graph = argument->ReleaseMainGraph();
auto graph = std::unique_ptr<Graph>(the_graph);
#ifdef PADDLE_WITH_MKLDNN
if (argument->Has("calibration_file_path")) {
VLOG(5) << "Calibration file path of quantize model: "
<< argument->calibration_file_path();
std::unordered_map<std::string, std::vector<float>> var_quant_scales{};
ReadCalibrationInfo(argument, &var_quant_scales);
// save var_quant_scales in the first op's attr
// for quant_dequant_mkldnn_pass
SaveInfoInTheFirstOp(
the_graph, "has_quant_info", "var_quant_scales", var_quant_scales);
}
#endif
// Apply passes.
IRPassManager the_ir_manager(argument);
graph = the_ir_manager.Apply(std::move(graph));
......@@ -44,6 +61,40 @@ void IrAnalysisPass::RunImpl(Argument* argument) {
CollectFusionStatis(argument);
}
void IrAnalysisPass::ReadCalibrationInfo(
Argument* argument,
std::unordered_map<std::string, std::vector<float>>* var_quant_scales) {
std::string calibration_file_path;
#ifdef PADDLE_WITH_MKLDNN
if (argument->Has("calibration_file_path")) {
calibration_file_path = argument->calibration_file_path();
}
#endif
if (calibration_file_path.empty()) {
LOG(INFO) << "argument has no calibration_file_path";
return;
}
std::ifstream calibration_file(calibration_file_path);
std::string one_line;
while (getline(calibration_file, one_line)) {
if (one_line.find(" ") != one_line.npos) {
auto pos = one_line.find(" ");
std::string pre_str = one_line.substr(0, pos);
std::string pos_str = one_line.substr(pos);
if (pre_str.size() && pos_str.size()) {
std::string tensor_name = pre_str;
float scale = std::stod(pos_str);
scale = 1.0 / scale;
if (std::isinf(scale) || std::isnan(scale)) {
continue;
}
std::vector<float> scales = {scale};
(*var_quant_scales)[tensor_name] = scales;
}
}
}
}
void IrAnalysisPass::CollectFusionStatis(Argument* argument) {
if (!argument->main_graph().Has(framework::ir::kFuseStatisAttr)) {
LOG(INFO) << "argument has no fuse statis";
......
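The calibration file read by ReadCalibrationInfo above is plain text with one `tensor_name scale` pair per line (as written by the test at the end of this diff); the pass stores the reciprocal 1/scale per tensor and skips entries whose reciprocal is inf or nan. A hedged Python sketch of equivalent parsing (the file name is a placeholder):

```python
import math

def read_calibration_info(path="calibration_table.txt"):
    # One "tensor_name scale" pair per line; store 1/scale like the pass does.
    var_quant_scales = {}
    with open(path) as f:
        for line in f:
            fields = line.split(None, 1)
            if len(fields) != 2:
                continue
            name, value = fields[0], float(fields[1])
            scale = float("inf") if value == 0 else 1.0 / value
            if math.isinf(scale) or math.isnan(scale):
                continue
            var_quant_scales[name] = [scale]
    return var_quant_scales
```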
......@@ -33,6 +33,10 @@ class IrAnalysisPass : public AnalysisPass {
void CollectFusionStatis(Argument* argument);
void ReadCalibrationInfo(
Argument* argument,
std::unordered_map<std::string, std::vector<float>>* var_quant_scales);
std::string repr() const override;
};
......
......@@ -246,6 +246,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(opt_cache_dir_);
CP_MEMBER(prog_file_);
CP_MEMBER(params_file_);
CP_MEMBER(calibration_file_path_);
CP_MEMBER(use_fc_padding_);
// GPU related.
......@@ -516,6 +517,14 @@ void AnalysisConfig::EnableMkldnnInt8(
Update();
}
void AnalysisConfig::SetCalibrationFilePath(
const std::string &calibration_file_path) {
calibration_file_path_ = calibration_file_path;
VLOG(1) << "Set calibration file path of quantize model: " +
calibration_file_path_;
Update();
}
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
platform::errors::PreconditionNotMet(
......@@ -827,6 +836,8 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << prog_file_;
ss << params_file_;
ss << calibration_file_path_;
ss << use_gpu_;
ss << use_external_stream_;
ss << exec_stream_;
......@@ -1009,6 +1020,10 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"model_file", prog_file_});
os.InsertRow({"params_file", params_file_});
}
if (!(calibration_file_path_.empty())) {
os.InsertRow({"calibration_file_path", calibration_file_path_});
}
if (model_from_memory_) {
os.InsertRow({"model_from_memory", params_file_});
}
......
......@@ -1194,6 +1194,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
argument_.SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
argument_.SetQuantVarScales({});
argument_.SetCalibrationFilePath(config_.calibration_file_path_);
}
#endif
......
......@@ -763,6 +763,18 @@ struct PD_INFER_DECL AnalysisConfig {
///
void EnableMkldnnQuantizer();
///
/// \brief Set the calibration ranges file path of the quantized model.
///
///
void SetCalibrationFilePath(const std::string& calibration_file_path = "");
///
/// \brief Return the calibration ranges file path of the quantized model.
///
///
std::string CalibrationFilePath() { return calibration_file_path_; }
///
/// \brief Turn on MKLDNN int8.
///
......@@ -941,6 +953,7 @@ struct PD_INFER_DECL AnalysisConfig {
std::string model_dir_;
mutable std::string prog_file_;
mutable std::string params_file_;
mutable std::string calibration_file_path_;
// Mixed precision.
std::unordered_set<std::string> mixed_black_list_;
......
......@@ -759,6 +759,9 @@ void BindAnalysisConfig(py::module *m) {
.def("to_native_config", &AnalysisConfig::ToNativeConfig)
.def("enable_quantizer", &AnalysisConfig::EnableMkldnnQuantizer)
.def("enable_mkldnn_bfloat16", &AnalysisConfig::EnableMkldnnBfloat16)
.def("set_calibration_file_path",
&AnalysisConfig::SetCalibrationFilePath,
py::arg("calibration_file_path") = std::string(""))
#ifdef PADDLE_WITH_MKLDNN
.def("quantizer_config",
&AnalysisConfig::mkldnn_quantizer_config,
......
File mode changed from 100644 to 100755
......@@ -4,9 +4,19 @@ file(
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS "test_onnx_format_quantization_mobilenetv1")
if(WITH_MKLDNN AND NOT WIN32)
list(APPEND TEST_OPS "test_onnx_format_quantization_mobilenetv1")
endif()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach()
set_tests_properties(test_concat_mkldnn_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv3d_mkldnn_op PROPERTIES TIMEOUT 120)
if(WITH_MKLDNN AND NOT WIN32)
set_tests_properties(test_onnx_format_quantization_mobilenetv1
PROPERTIES TIMEOUT 300)
endif()
set_tests_properties(test_flags_mkldnn_ops_on_off PROPERTIES TIMEOUT 120)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import time
import sys
import random
import math
import functools
import contextlib
import tempfile
import numpy as np
from PIL import Image, ImageEnhance
import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
paddle.enable_static()
random.seed(0)
np.random.seed(0)
DATA_DIM = 224
THREAD = 1
BUF_SIZE = 102400
DATA_DIR = 'data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def process_image(sample, mode, color_jitter, rotate):
img_path = sample[0]
img = Image.open(img_path)
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
return img, sample[1]
def _reader_creator(file_list,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
np.random.shuffle(full_lines)
lines = full_lines
for line in lines:
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
if not os.path.exists(img_path):
continue
yield img_path, int(label)
mapper = functools.partial(process_image,
mode=mode,
color_jitter=color_jitter,
rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def val(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
class TestPostTrainingQuantization(unittest.TestCase):
def setUp(self):
self.int8_download = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
self.int8_download)
self.data_cache_folder = ''
data_urls = []
data_md5s = []
if os.environ.get('DATASET') == 'full':
data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa'
)
data_md5s.append('60f6525b0e1d127f345641d75d41f0a8')
data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab'
)
data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5')
self.data_cache_folder = self.download_data(data_urls, data_md5s,
"full_data", False)
else:
data_urls.append(
'http://paddle-inference-dist.bj.bcebos.com/int8/calibration_test_data.tar.gz'
)
data_md5s.append('1b6c1c434172cca1bf9ba1e4d7a3157d')
self.data_cache_folder = self.download_data(data_urls, data_md5s,
"small_data", False)
# reader/decorator.py requires the relative path to the data folder
if not os.path.exists("./data/ILSVRC2012"):
cmd = 'rm -rf {0} && ln -s {1} {0}'.format("data",
self.data_cache_folder)
os.system(cmd)
self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50
self.sample_iterations = 50 if os.environ.get(
'DATASET') == 'full' else 2
self.infer_iterations = 50000 if os.environ.get(
'DATASET') == 'full' else 2
self.root_path = tempfile.TemporaryDirectory()
self.int8_model = os.path.join(self.root_path.name,
"post_training_quantization")
print("self.int8_model: ", self.int8_model)
def tearDown(self):
self.root_path.cleanup()
pass
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
target_folder, zip_path)
os.system(cmd)
def download_data(self, data_urls, data_md5s, folder_name, is_model=True):
data_cache_folder = os.path.join(self.cache_folder, folder_name)
zip_path = ''
if os.environ.get('DATASET') == 'full':
file_names = []
for i in range(0, len(data_urls)):
download(data_urls[i], self.int8_download, data_md5s[i])
file_names.append(data_urls[i].split('/')[-1])
zip_path = os.path.join(self.cache_folder,
'full_imagenet_val.tar.gz')
if not os.path.exists(zip_path):
cat_command = 'cat'
for file_name in file_names:
cat_command += ' ' + os.path.join(self.cache_folder,
file_name)
cat_command += ' > ' + zip_path
os.system(cat_command)
if os.environ.get('DATASET') != 'full' or is_model:
download(data_urls[0], self.int8_download, data_md5s[0])
file_name = data_urls[0].split('/')[-1]
zip_path = os.path.join(self.cache_folder, file_name)
print('Data is downloaded at {0}'.format(zip_path))
self.cache_unzipping(data_cache_folder, zip_path)
return data_cache_folder
def download_model(self):
pass
def run_program(self,
model_path,
batch_size,
infer_iterations,
is_quantized_model=False):
image_shape = [3, 224, 224]
config = paddle.inference.Config(model_path)
config.disable_gpu()
config.enable_mkldnn()
config.switch_ir_optim()
config.set_cpu_math_library_num_threads(1)
config.disable_glog_info()
if is_quantized_model:
calibration_file_path = os.path.join(model_path,
'calibration_table.txt')
config.set_calibration_file_path(calibration_file_path)
config.enable_mkldnn_int8()
predictor = paddle.inference.create_predictor(config)
input_names = predictor.get_input_names()
image_tensor = predictor.get_input_handle(input_names[0])
label_tensor = predictor.get_input_handle(input_names[1])
output_names = predictor.get_output_names()
acc_tensor = predictor.get_output_handle("accuracy_0.tmp_0")
val_reader = paddle.batch(val(), batch_size)
iterations = infer_iterations
test_info = []
cnt = 0
periods = []
for batch_id, data in enumerate(val_reader()):
image = np.array([x[0].reshape(image_shape)
for x in data]).astype("float32")
label = np.array([x[1] for x in data]).astype("int64")
label = label.reshape([-1, 1])
t1 = time.time()
image_tensor.copy_from_cpu(image)
label_tensor.copy_from_cpu(label)
predictor.run()
acc1 = acc_tensor.copy_to_cpu()
t2 = time.time()
period = t2 - t1
periods.append(period)
test_info.append(np.mean(acc1) * len(data))
cnt += len(data)
if (batch_id + 1) % 100 == 0:
print("{0} images,".format(batch_id + 1))
sys.stdout.flush()
if (batch_id + 1) == iterations:
break
throughput = cnt / np.sum(periods)
latency = np.average(periods)
acc1 = np.sum(test_info) / cnt
return (throughput, latency, acc1)
def generate_quantized_model(self,
model_path,
quantizable_op_type,
algo="KL",
round_type="round",
is_full_quantize=False,
is_use_cache_file=False,
is_optimize_model=False,
onnx_format=False):
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
print("Failed to create {} due to {}".format(
self.int8_model, str(e)))
sys.exit(-1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
val_reader = val()
ptq = PostTrainingQuantization(executor=exe,
sample_generator=val_reader,
model_dir=model_path,
algo=algo,
quantizable_op_type=quantizable_op_type,
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
if onnx_format:
try:
collect_dict = ptq._calibration_scales
save_quant_table_path = os.path.join(self.int8_model,
'calibration_table.txt')
with open(save_quant_table_path, 'w') as txt_file:
for tensor_name in collect_dict.keys():
write_line = '{} {}'.format(
tensor_name,
collect_dict[tensor_name]['scale']) + '\n'
txt_file.write(write_line)
print(
"Quantization clip ranges of tensors is save in: {}".format(
save_quant_table_path))
except:
print(
"Unable to generate `calibration_table.txt`, please update PaddlePaddle >= 2.3.3"
)
def run_test(self,
model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=True):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
model_cache_folder = self.download_data(data_urls, data_md5s, model)
print("Start FP32 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
os.path.join(model_cache_folder, "model"), batch_size,
infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
quantizable_op_type, algo, round_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, onnx_format)
print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
(int8_throughput, int8_latency,
int8_acc1) = self.run_program(self.int8_model,
batch_size,
infer_iterations,
is_quantized_model=True)
print("---Post training quantization of {} method---".format(algo))
print(
"FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}."
.format(model, batch_size, fp32_throughput, fp32_latency,
fp32_acc1))
print(
"INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.\n"
.format(model, batch_size, int8_throughput, int8_latency,
int8_acc1))
sys.stdout.flush()
delta_value = int8_latency - fp32_latency
self.assertLess(delta_value, diff_threshold)
class TestMKLDNNInt8ForMobilenetv1AvgONNXFormat(TestPostTrainingQuantization):
def test_onnx_format_avg_mobilenetv1(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0
self.run_test(model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=True)
class TestMKLDNNInt8ForMobilenetv1Avg(TestPostTrainingQuantization):
def test_avg_mobilenetv1(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0
self.run_test(model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False)
class TestMKLDNNInt8ForMobilenetv1AbsMaxONNXFormat(TestPostTrainingQuantization
):
def test_onnx_format_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
# The accuracy diff of post-training quantization (abs_max) may be bigger
diff_threshold = 0
self.run_test(model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=True)
class TestMKLDNNInt8ForMobilenetv1AbsMax(TestPostTrainingQuantization):
def test_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
# The accuracy diff of post-training quantization (abs_max) may be bigger
diff_threshold = 0
self.run_test(model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False)
if __name__ == '__main__':
unittest.main()