Unverified · Commit 561fd8c8 authored by yeliang2258, committed by GitHub

Fix quantized-model deployment bugs when using MKLDNN (#45920)

* fix immutable op quantize bugs

* fix

* fix build bug

* fix test

* notest,test=inference

* fix ppyoloe acc drop bugs

* fix test

* fix test

* add test

* fix

* fix

* fix test

* fix refined name bug

* fix test

* bias fix

* fix matmul weight dequant bug

* re-ci

* fix tester

* fix test

* fix tester

* update weight dequantize func

* update code

* update test for coverage

* update test

* update cmake

* update cmakelist

* update code

* rerun ci

* remove useless code
Parent 910e1b6a
@@ -67,7 +67,7 @@ std::vector<float> ComputePropagateScalesMkldnnPass::GetScales(
   for (int i = 0; i < columns; i++) {
     float max_value = FLT_MIN;
     for (int j = 0; j < rows; j++) {
-      max_value = std::max(max_value, std::abs(data[i + j * columns]));
+      max_value = std::max(max_value, std::abs(data[j + i * rows]));
     }
     max_value = 1.0 / max_value;
     if (std::isinf(max_value) || std::isnan(max_value)) {
......
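Read in isolation, the corrected loop computes one reciprocal max-abs scale per column of a weight buffer stored column by column, so element (row j, column i) lives at offset j + i * rows. A minimal standalone sketch of that computation (the function name and signature are invented here for illustration, not Paddle's API):

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

// Per-column reciprocal max-abs scales over a column-wise weight buffer.
std::vector<float> ColumnMaxAbsReciprocal(const float* data, int rows, int columns) {
  std::vector<float> scales(columns, 0.0f);
  for (int i = 0; i < columns; i++) {
    float max_value = FLT_MIN;
    for (int j = 0; j < rows; j++) {
      max_value = std::max(max_value, std::abs(data[j + i * rows]));
    }
    float scale = 1.0f / max_value;
    // Guard against all-zero columns, mirroring the isinf/isnan check above.
    scales[i] = (std::isinf(scale) || std::isnan(scale)) ? 0.0f : scale;
  }
  return scales;
}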
@@ -411,7 +411,16 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
     auto filter_scale_tensor = GetScaleTensorForNode(conv_filter);
     EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data<double>(),
                                      filter_scale_tensor.numel()};
-    eigen_tensor *= static_cast<double>(S8_MAX);
+    // If the scale value of a weight is already multiplied by S8_MAX, it does
+    // not need to be multiplied again
+    if (std::find(change_weight_->begin(),
+                  change_weight_->end(),
+                  conv_filter->Name()) == change_weight_->end()) {
+      eigen_tensor *= static_cast<double>(S8_MAX);
+      change_weight_->push_back(conv_filter->Name());
+    }
     std::vector<float> filter_scale{
         filter_scale_tensor.data<double>(),
         filter_scale_tensor.data<double>() + filter_scale_tensor.numel()};
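The guard above matters when two convolutions share one filter: the per-channel scale must be multiplied by S8_MAX exactly once (a scale of 0.02 becomes 0.02 * 127 = 2.54 after the first multiplication; applying it again would inflate it to roughly 322.6). A minimal sketch of that bookkeeping, with invented names rather than the pass's actual members:

#include <algorithm>
#include <string>
#include <vector>

constexpr double S8_MAX = 127.0;

// Multiply a filter's per-channel scales by S8_MAX only on first encounter.
void ScaleFilterOnce(const std::string& filter_name,
                     std::vector<double>* filter_scales,
                     std::vector<std::string>* scaled_filters) {
  if (std::find(scaled_filters->begin(), scaled_filters->end(), filter_name) !=
      scaled_filters->end()) {
    return;  // this shared filter was already handled by another conv
  }
  for (double& s : *filter_scales) s *= S8_MAX;
  scaled_filters->push_back(filter_name);
}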
@@ -693,6 +702,13 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph,
       return;
     }
 
+    // skip if the dtype of immutable_in is not float32
+    auto dtype = immutable_in->Var()->GetDataType();
+    if (dtype != proto::VarType::FP32) {
+      MarkAndLogCannotQuantizeOp(immutable_op, "The input dtype is not float.");
+      return;
+    }
+
     if (!AreScalesPresentForNodes({immutable_out})) {
       MarkAndLogCannotQuantizeOp(immutable_op,
                                  "No scale available for the operator");
@@ -1164,7 +1180,6 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
   QuantizeImmutable(graph, "reshape2", "X");
   QuantizeImmutable(graph, "transpose2", "X");
   QuantizeImmutable(graph, "slice", "Input");
-  QuantizeImmutable(graph, "shape", "Input");
   QuantizeImmutable(graph, "nearest_interp", "X");
   QuantizeImmutable(graph, "nearest_interp_v2", "X");
   QuantizeElementwise(graph, "elementwise_add");
......
@@ -110,6 +110,11 @@ class CPUQuantizePass : public FusePassBase {
   VarQuantScale string_pair_map = {};
   VarQuantScale* const var_quant_scales_ = &string_pair_map;
 
+  // Record the names of weights whose scales have already been processed,
+  // so they are not processed a second time
+  std::vector<std::string> change_weight = {};
+  std::vector<std::string>* const change_weight_ = &change_weight;
+
   void GetQuantInfo(Graph* graph) const;
 };
......
@@ -66,7 +66,7 @@ void SetOp(ProgramDesc* prog,
              type == "nearest_interp" || type == "nearest_interp_v2") {
     op->SetInput("X", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
-  } else if (type == "slice" || type == "shape") {
+  } else if (type == "slice") {
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
   } else if (type == "dropout") {
@@ -467,7 +467,7 @@ static const std::initializer_list<std::string> variable_names_immutable_ops = {
 void TestImmutableOp(const std::string tested_op) {
   ProgramDesc prog;
   for (auto& v : variable_names_immutable_ops) {
-    prog.MutableBlock(0)->Var(v);
+    prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
   }
   SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
   SetOp(&prog, tested_op, tested_op, {"b"}, {"c"}, true, "int8");
@@ -520,7 +520,7 @@ void TestImmutableOpBetweenNonQuantizedOp(const std::string tested_op) {
 void TestImmutableOpWithManyOutputs(const std::string tested_op) {
   ProgramDesc prog;
   for (auto& v : variable_names_immutable_ops) {
-    prog.MutableBlock(0)->Var(v);
+    prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
   }
   SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, true, "float32");
@@ -556,12 +556,8 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) {
                   SCALE * S8_MAX);
 }
 
-const std::vector<std::string> immutables = {"reshape2",
-                                             "transpose2",
-                                             "slice",
-                                             "shape",
-                                             "nearest_interp",
-                                             "nearest_interp_v2"};
+const std::vector<std::string> immutables = {
+    "reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"};
 
 class TestImmutables : public testing::TestWithParam<std::string> {};
......
@@ -52,36 +52,25 @@ bool HasBias(ir::Node* conv_op) {
          conv_op->Op()->Input("Bias").size() > 0;
 }
 
-bool ShouldSkipConv(ir::Node* conv_op, Scope* scope, ir::Node* conv_filter) {
-  if (!platform::HasOpINT8DataType(conv_op->Op())) {
-    VLOG(4) << "Skipping non-int8 convolution (id: " << conv_op->id() << ").";
-    return true;
-  }
-  auto filter_var = scope->GetVar(conv_filter->Name());
-  if (filter_var->Get<LoDTensor>().dtype() != phi::DataType::FLOAT32) {
-    VLOG(4) << "Skipping convolution (id: " << conv_op->id()
-            << ") because it's a bug that it is detected again.";
-    return true;
-  }
-  VLOG(4) << "Not skipping convolution (id: " << conv_op->id() << ")";
-  return false;
-}
-
 template <typename T>
 void QuantizeConvInput(Scope* scope,
                        ir::Graph* g,
                        ir::Node* conv_op,
                        const std::string& input_name,
                        const std::string& scales_attr_name) {
-  const auto scales =
-      conv_op->Op()->GetAttrIfExists<std::vector<float>>(scales_attr_name);
-
-  auto* tensor = scope->GetVar(input_name)->GetMutable<LoDTensor>();
-  QuantizeParams<T>(tensor, scales);
-
-  conv_op->Op()->SetAttr(scales_attr_name, std::vector<float>(1, 1));
+  auto var = scope->GetVar(input_name);
+  if (var->Get<LoDTensor>().dtype() != phi::DataType::FLOAT32) {
+    VLOG(0) << "Skipping convolution filter: " << input_name
+            << " because it is detected again.";
+    conv_op->Op()->SetAttr(scales_attr_name, std::vector<float>(1, 1));
+  } else {
+    const auto scales =
+        conv_op->Op()->GetAttrIfExists<std::vector<float>>(scales_attr_name);
+    auto* tensor = scope->GetVar(input_name)->GetMutable<LoDTensor>();
+    QuantizeParams<T>(tensor, scales);
+    conv_op->Op()->SetAttr(scales_attr_name, std::vector<float>(1, 1));
+  }
 }
 
 }  // namespace
@@ -151,7 +140,8 @@ void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph,
   PADDLE_ENFORCE_NOT_NULL(
       scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
 
-  if (ShouldSkipConv(conv_op, scope, conv_filter)) {
+  // If not a quantized OP
+  if (!platform::HasOpINT8DataType(conv_op->Op())) {
     return;
   }
......
@@ -89,8 +89,14 @@ struct ProgramStrategy {
   virtual void CheckOp(const OpDesc& op) const = 0;
 
-  VarDesc* AddInput(OpDesc* op, std::string input_name, const Data& data) {
-    const std::string var_name = input_name + "_var";
+  VarDesc* AddInput(OpDesc* op,
+                    std::string input_name,
+                    const Data& data,
+                    const std::string user_var_name = "") {
+    std::string var_name = user_var_name;
+    if (var_name.empty()) {
+      var_name = input_name + "_var";
+    }
     op->SetInput(input_name, {var_name});
     auto var = program.MutableBlock(0)->Var(var_name);
     var->SetShape(data.getShape());
@@ -98,8 +104,14 @@ struct ProgramStrategy {
     return var;
   }
 
-  void AddOutput(OpDesc* op, std::string output_name, const Data& data) {
-    const std::string var_name = output_name + "_var";
+  void AddOutput(OpDesc* op,
+                 std::string output_name,
+                 const Data& data,
+                 const std::string user_var_name = "") {
+    std::string var_name = user_var_name;
+    if (var_name.empty()) {
+      var_name = output_name + "_var";
+    }
     op->SetOutput(output_name, {var_name});
     program.MutableBlock(0)->Var(var_name);
     test_scope.CreateTensor(var_name, data);
@@ -117,21 +129,23 @@ struct ConvProgramStrategy : public ProgramStrategy {
                       std::vector<float>&& scale_weights,
                       int groups = 1,
                       Data&& bias = Data(),
-                      std::vector<float>&& scale_bias = {})
+                      std::vector<float>&& scale_bias = {},
+                      bool share_weight = false)
       : input(std::move(input)),
        filter(std::move(filter)),
        output(std::move(output)),
        scale_weights(std::move(scale_weights)),
        groups(std::move(groups)),
        bias(std::move(bias)),
-       scale_bias(std::move(scale_bias)) {}
+       scale_bias(std::move(scale_bias)),
+       share_weight(std::move(share_weight)) {}
 
  protected:
-  OpDesc* CreateBasicConvOp() {
+  OpDesc* CreateBasicConvOp(const std::string conv_name = "Conv1") {
     auto op = program.MutableBlock(0)->AppendOp();
     op->SetType("conv2d");
     op->SetAttr("use_mkldnn", true);
-    op->SetAttr("name", std::string{"Conv1"});
+    op->SetAttr("name", conv_name);
     op->SetAttr("mkldnn_data_type", std::string{"int8"});
     op->SetAttr("data_format", std::string{"NCHW"});
     op->SetAttr("dilations", std::vector<int>({1, 1}));
@@ -155,6 +169,20 @@ struct ConvProgramStrategy : public ProgramStrategy {
       AddInput(op, "Bias", bias);
       op->SetAttr("Bias_scales", scale_bias);
     }
+
+    if (share_weight) {
+      OpDesc* op2 = CreateBasicConvOp("Conv2");
+      AddInput(op2, "Input", input);
+      AddInput(op2, "Filter", filter)->SetPersistable(true);
+      AddOutput(op2, "Output", output, "output2");
+      op2->SetAttr("Scale_weights", scale_weights);
+      op2->SetAttr("Scale_in", 1.0f);
+      op2->SetAttr("groups", groups);
+      if (HasBias()) {
+        AddInput(op2, "Bias", bias, "Bias2");
+        op2->SetAttr("Bias_scales", scale_bias);
+      }
+    }
   }
 
   void CheckOp(const OpDesc& op) const override {
@@ -210,9 +238,9 @@ struct ConvProgramStrategy : public ProgramStrategy {
   const Data output;
   const std::vector<float> scale_weights;
   const int groups;
   const Data bias;
   const std::vector<float> scale_bias;
+  const bool share_weight;
 };
 
 struct ParamsQuantizationMkldnnPassTestFixture : public ::testing::Test {
@@ -340,6 +368,19 @@ TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1w) {
   RunPassTest(std::move(program));
 }
 
+TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1ws) {
+  auto program = std::make_unique<ConvProgramStrategy>(
+      GenericInput(),
+      Data({2, 2, 2, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}),
+      GenericOutput(),
+      std::vector<float>{2.f, 2.f, 4.f, 4.f},
+      2,
+      Data({2, 2, 1, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}),
+      std::vector<float>{2.f, 2.f, 4.f, 4.f},
+      true);
+  RunPassTest(std::move(program));
+}
+
 }  // namespace
 }  // namespace ir
 }  // namespace framework
......
@@ -109,27 +109,34 @@ void QuantDequantMkldnnPass::CollectWeightScalesInfoFromONNXFormatDequantize(
     if (op_node->Name() == "dequantize_linear") {
       auto* op_desc = op_node->Op();
+      auto scale_name = op_desc->Input("Scale")[0];
+      auto* var = scope->FindVar(scale_name);
+      PADDLE_ENFORCE_NOT_NULL(
+          var,
+          platform::errors::NotFound(
+              "The Scales variable [%s] of dequantize op is not found.", var));
+      auto* scale_tensor = var->GetMutable<LoDTensor>();
+      auto* scale_data = scale_tensor->data<float>();
       auto x_var_name = op_desc->Input("X")[0];
       auto* weight_var = scope->FindVar(x_var_name);
       if (!weight_var) {
         auto out_var_name = op_desc->Output("Y")[0];
-        if (var_quant_scales->count(x_var_name) &&
-            !var_quant_scales->count(out_var_name)) {
-          std::vector<float> scale_v = var_quant_scales->at(x_var_name);
+        float scale = 1.0 / scale_data[0];
+        if (std::isinf(scale) || std::isnan(scale)) {
+          scale = 0.0;
+        }
+        std::vector<float> scale_v = {scale};
+        if (!var_quant_scales->count(out_var_name)) {
           var_quant_scales->insert(std::make_pair(out_var_name, scale_v));
         }
+        if (!var_quant_scales->count(x_var_name)) {
+          var_quant_scales->insert(std::make_pair(x_var_name, scale_v));
+        }
       } else {
         *onnx_format_quantize_model = true;
-        auto scale_name = op_desc->Input("Scale")[0];
-        auto* var = scope->FindVar(scale_name);
-        PADDLE_ENFORCE_NOT_NULL(
-            var,
-            platform::errors::NotFound(
-                "The Scales variable [%s] of dequantize op is not found.",
-                var));
-        auto* scale_tensor = var->GetMutable<LoDTensor>();
-        auto* scale_data = scale_tensor->data<float>();
         std::vector<float> thresholds(scale_data,
                                       scale_data + scale_tensor->numel());
         weight_thresholds->insert(std::make_pair(x_var_name, thresholds));
@@ -182,7 +189,7 @@ void QuantDequantMkldnnPass::CollectInputScalesFromQuantize(
       auto* scale_data = scale_tensor->data<float>();
       float scale = 1.0 / scale_data[0];
       if (std::isinf(scale) || std::isnan(scale)) {
-        scale = 0.0;
+        continue;
       }
       if (!var_quant_scales->count(x_var_name)) {
@@ -520,12 +527,10 @@ void QuantDequantMkldnnPass::ConvertFromINT8ToFP32(
       int step_c = step_n / size;
       for (int i = 0; i < weight_dims[0]; i++) {
         int begin_n = i * step_n;
-        for (int j = begin_n; j < begin_n + step_n; j++) {
-          for (int k = 0; k < size; k++) {
-            int begin_c = k * step_c;
-            for (int m = begin_c; m < begin_c + step_c; m++) {
-              weight_data[m] *= scales[k];
-            }
+        for (int j = 0; j < size; j++) {
+          int begin_c = begin_n + j * step_c;
+          for (int k = 0; k < step_c; k++) {
+            weight_data[begin_c + k] *= scales[j];
           }
         }
       }
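The rewritten loop multiplies every element of sub-block j inside each outer block of the flat weight buffer by scales[j]. The same arithmetic as a standalone sketch (function and parameter names are invented here for illustration, not Paddle's API):

#include <vector>

// The buffer holds weight_dims0 blocks of step_n elements; each block is split
// into scales.size() sub-blocks of step_c elements, scaled per sub-block.
void ScalePerGroup(std::vector<float>* weight_data,
                   const std::vector<float>& scales,
                   int weight_dims0, int step_n) {
  const int size = static_cast<int>(scales.size());
  const int step_c = step_n / size;
  for (int i = 0; i < weight_dims0; i++) {
    const int begin_n = i * step_n;
    for (int j = 0; j < size; j++) {
      const int begin_c = begin_n + j * step_c;  // offset of sub-block (i, j)
      for (int k = 0; k < step_c; k++) {
        (*weight_data)[begin_c + k] *= scales[j];
      }
    }
  }
}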
@@ -588,7 +593,8 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
     Scope* scope,
     const std::string& weight_name,
     const std::unordered_map<std::string, std::vector<float>>&
-        weight_thresholds) const {
+        weight_thresholds,
+    std::vector<std::string>* dequantized_weights_names) const {
   auto* op_desc = op_node->Op();
   std::string weight_var_name = op_desc->Input(weight_name)[0];
@@ -596,6 +602,13 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
   auto iter = weight_thresholds.find(weight_var_name);
   if (iter != weight_thresholds.end()) {
     scales = iter->second;
+    auto name_iter = std::find(dequantized_weights_names->begin(),
+                               dequantized_weights_names->end(),
+                               weight_var_name);
+    // Has been dequantized
+    if (name_iter != dequantized_weights_names->end()) {
+      return;
+    }
   } else {
     if (!IsInt8Weight(op_node, scope, weight_name)) {
       return;
@@ -605,7 +618,7 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
           "the model is correct.",
           weight_var_name));
   }
+  dequantized_weights_names->push_back(weight_var_name);
   auto* var = scope->FindVar(weight_var_name);
   PADDLE_ENFORCE_NOT_NULL(
       var,
@@ -634,14 +647,17 @@ void QuantDequantMkldnnPass::DequantizeWeights(
         << "No need to dequantize weights because weight_thresholds is empty.";
     return;
   }
+  std::vector<std::string> dequantized_weights_names;
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
     if (!op_node->IsOp()) continue;
     if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") {
       if (onnx_format_quantize_model) {
-        DequantizeOpWeightsFromONNXFormat(
-            op_node, scope, "Filter", weight_thresholds);
+        DequantizeOpWeightsFromONNXFormat(op_node,
+                                          scope,
+                                          "Filter",
+                                          weight_thresholds,
+                                          &dequantized_weights_names);
       } else if (IsInt8Weight(op_node, scope, "Filter")) {
         DequantizeOpWeights(
             op_node, scope, "Filter", "Output", weight_thresholds);
@@ -650,7 +666,7 @@ void QuantDequantMkldnnPass::DequantizeWeights(
                op_node->Name() == "matmul_v2") {
       if (onnx_format_quantize_model) {
         DequantizeOpWeightsFromONNXFormat(
-            op_node, scope, "Y", weight_thresholds);
+            op_node, scope, "Y", weight_thresholds, &dequantized_weights_names);
       } else if (IsInt8Weight(op_node, scope, "Y")) {
         DequantizeOpWeights(op_node, scope, "Y", "Out", weight_thresholds);
       }
......
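The dequantized_weights_names list exists because one weight tensor can be shared by several conv2d or matmul ops; dequantizing it a second time would rescale data that is already FP32. A minimal sketch of that check (the function and parameter names are invented for illustration, not Paddle's API):

#include <algorithm>
#include <string>
#include <vector>

// Returns true the first time a weight name is seen; callers dequantize only then.
bool DequantizeOnce(const std::string& weight_var_name,
                    std::vector<std::string>* dequantized_weights_names) {
  if (std::find(dequantized_weights_names->begin(),
                dequantized_weights_names->end(),
                weight_var_name) != dequantized_weights_names->end()) {
    return false;  // already converted back to FP32 by another op
  }
  dequantized_weights_names->push_back(weight_var_name);
  return true;
}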
@@ -125,7 +125,8 @@ class QuantDequantMkldnnPass : public FusePassBase {
       Scope* scope,
       const std::string& weight_name,
       const std::unordered_map<std::string, std::vector<float>>&
-          weight_thresholds) const;
+          weight_thresholds,
+      std::vector<std::string>* dequantized_weights_names) const;
 
   void DequantizeWeights(
       ir::Graph* graph,
......
@@ -17,7 +17,6 @@ import time
 import sys
 import random
 import functools
-import tempfile
 import numpy as np
 from PIL import Image
 import paddle
@@ -149,13 +148,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
         self.infer_iterations = 50000 if os.environ.get(
             'DATASET') == 'full' else 2
 
-        self.root_path = tempfile.TemporaryDirectory()
-        self.int8_model = os.path.join(self.root_path.name,
-                                       "post_training_quantization")
+        self.int8_model = "post_training_quantization"
         print("self.int8_model: ", self.int8_model)
 
     def tearDown(self):
-        self.root_path.cleanup()
+        cmd = 'rm -rf post_training_quantization'
+        os.system(cmd)
+        pass
 
     def cache_unzipping(self, target_folder, zip_path):
         if not os.path.exists(target_folder):
@@ -262,16 +261,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  is_use_cache_file=False,
                                  is_optimize_model=False,
                                  onnx_format=False):
-        try:
-            os.system("mkdir " + self.int8_model)
-        except Exception as e:
-            print("Failed to create {} due to {}".format(
-                self.int8_model, str(e)))
-            sys.exit(-1)
-
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
-        scope = fluid.global_scope()
         val_reader = val()
 
         ptq = PostTrainingQuantization(executor=exe,
@@ -305,12 +296,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
         model_cache_folder = self.download_data(data_urls, data_md5s, model)
 
-        print("Start FP32 inference for {0} on {1} images ...".format(
-            model, infer_iterations * batch_size))
-        (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
-            os.path.join(model_cache_folder, "model"), batch_size,
-            infer_iterations)
-
         print("Start INT8 post training quantization for {0} on {1} images ...".
               format(model, sample_iterations * batch_size))
         self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
@@ -318,6 +303,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                       is_full_quantize, is_use_cache_file,
                                       is_optimize_model, onnx_format)
 
+        print("Start FP32 inference for {0} on {1} images ...".format(
+            model, infer_iterations * batch_size))
+        (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
+            os.path.join(model_cache_folder, "model"), batch_size,
+            infer_iterations)
+
         print("Start INT8 inference for {0} on {1} images ...".format(
             model, infer_iterations * batch_size))
         (int8_throughput, int8_latency,
@@ -341,10 +332,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
         self.assertLess(delta_value, diff_threshold)
 
 
-class TestMKLDNNInt8ForMobilenetv1AvgONNXFormat(TestPostTrainingQuantization):
+class TestMKLDNNInt8ForResnet50AvgONNXFormat(TestPostTrainingQuantization):
 
-    def test_onnx_format_avg_mobilenetv1(self):
-        model = "MobileNet-V1"
+    def test_onnx_format_avg_resnet50(self):
+        model = "resnet50"
         algo = "avg"
         round_type = "round"
         data_urls = [
@@ -373,66 +364,5 @@ TestMKLDNNInt8ForMobilenetv1AvgONNXFormat(TestPostTrainingQuantization):
                       onnx_format=True)
 
 
-class TestMKLDNNInt8ForMobilenetv1Avg(TestPostTrainingQuantization):
-
-    def test_avg_mobilenetv1(self):
-        model = "MobileNet-V1"
-        algo = "avg"
-        round_type = "round"
-        data_urls = [
-            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
-        ]
-        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
-        quantizable_op_type = [
-            "conv2d",
-            "depthwise_conv2d",
-            "mul",
-        ]
-        is_full_quantize = False
-        is_use_cache_file = False
-        is_optimize_model = False
-        diff_threshold = 0
-        self.run_test(model,
-                      algo,
-                      round_type,
-                      data_urls,
-                      data_md5s,
-                      quantizable_op_type,
-                      is_full_quantize,
-                      is_use_cache_file,
-                      is_optimize_model,
-                      diff_threshold,
-                      onnx_format=False)
-
-
-class TestMKLDNNInt8ForMobilenetv1AbsMax(TestPostTrainingQuantization):
-
-    def test_abs_max_mobilenetv1(self):
-        model = "MobileNet-V1"
-        algo = "abs_max"
-        round_type = "round"
-        data_urls = [
-            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
-        ]
-        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
-        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
-        is_full_quantize = False
-        is_use_cache_file = False
-        is_optimize_model = False
-        # The accuracy diff of post-training quantization (abs_max) maybe bigger
-        diff_threshold = 0
-        self.run_test(model,
-                      algo,
-                      round_type,
-                      data_urls,
-                      data_md5s,
-                      quantizable_op_type,
-                      is_full_quantize,
-                      is_use_cache_file,
-                      is_optimize_model,
-                      diff_threshold,
-                      onnx_format=False)
-
-
 if __name__ == '__main__':
     unittest.main()