From e68da187fd59a476828b723510139ac9af85cc35 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Sun, 10 Apr 2022 10:28:17 +0800 Subject: [PATCH] add mkldnn int8 pass [step1] (#41579) * add mkldnn int8 pass * add mkldnn int8 pass * update pass --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/mkldnn/mkldnn_pass_util.h | 77 +++ .../ir/mkldnn/quant_dequant_mkldnn_pass.cc | 582 ++++++++++++++++++ .../ir/mkldnn/quant_dequant_mkldnn_pass.h | 91 +++ 4 files changed, 751 insertions(+) create mode 100644 paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h create mode 100644 paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 16a95b2ccf..4ee0b08375 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -140,6 +140,7 @@ if(WITH_MKLDNN) pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) pass_library(multi_gru_fuse_pass inference DIR mkldnn) pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn) + pass_library(quant_dequant_mkldnn_pass inference DIR mkldnn) endif() if(WITH_IPU) diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h new file mode 100644 index 0000000000..505bb2739e --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h @@ -0,0 +1,77 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/graph_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +static void SaveInfoInTheFirstOp( + ir::Graph* graph, const std::string& flag, const std::string& key_suffix, + const std::unordered_map>& info_map) { + VLOG(3) << "save variables in the first op's attr"; + + const std::string suffix = "_" + key_suffix + "_" + flag; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp() || op_node->Op()->Type() == "feed" || + op_node->Op()->Type() == "fetch") + continue; + + op_node->Op()->SetAttr(flag, true); + for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { + op_node->Op()->SetAttr(iter->first + suffix, iter->second); + } + break; + } +} + +static void GetInfoFromTheFirstOp( + ir::Graph* graph, const std::string& flag, const std::string& key_suffix, + std::unordered_map>* info_map) { + VLOG(3) << "get variables from the first op's attr"; + + const std::string suffix = "_" + key_suffix + "_" + flag; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp() || op_node->Op()->Type() == "feed" || + op_node->Op()->Type() == "fetch") + continue; + + auto* op_desc = op_node->Op(); + if (op_desc->GetAttrIfExists(flag)) { + op_desc->RemoveAttr(flag); + std::vector attr_names = op_desc->AttrNames(); + for (auto fake_name : attr_names) { + size_t pos = fake_name.find(suffix); + if (pos != std::string::npos) { + std::string name = fake_name.substr(0, pos); + auto scales_vector = + BOOST_GET_CONST(std::vector, op_desc->GetAttr(fake_name)); + info_map->insert(std::make_pair(name, scales_vector)); + op_desc->RemoveAttr(fake_name); + } + } + break; + } + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc new file mode 100644 index 0000000000..808d043a4b --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc @@ -0,0 +1,582 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h" +#include +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +void QuantDequantMkldnnPass::MarkSkipQuantizedOps( + ir::Graph* graph, const std::unordered_set& skip_ops) const { + VLOG(3) << "mark skip quantized ops"; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + + if (skip_ops.count(op_node->Name())) { + auto* op_desc = op_node->Op(); + if (!op_desc->HasAttr("quantization_type")) { + bool is_quantized_op = true; + for (auto* node_input : op_node->inputs) { + for (auto* node_input_input : node_input->inputs) { + if (!node_input_input->IsOp()) continue; + if (node_input_input->Name().find("quantize_dequantize") == + std::string::npos) { + is_quantized_op = false; + break; + } + } + if (!is_quantized_op) break; + } + + if (!is_quantized_op) { + op_node->Op()->SetAttr("skip_quant", 1); + } + } + } + } +} + +void QuantDequantMkldnnPass::MarkSkipQuantizedPool2d(ir::Graph* graph) const { + VLOG(3) << "mark avg pool2d as skip quantized op"; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + + if (op_node->Name() == "pool2d") { + auto* op_desc = op_node->Op(); + auto pool_type = + BOOST_GET_CONST(std::string, op_desc->GetAttr("pooling_type")); + if (pool_type == "avg") { + op_node->Op()->SetAttr("skip_quant", 1); + } + } + } +} + +void QuantDequantMkldnnPass::CollectInfoFromFake( + ir::Graph* graph, Scope* scope, + const std::unordered_set& fake_dequantize_types, + std::unordered_map>* weight_thresholds) + const { + VLOG(3) << "gather weight_thresholds from fake dequantized ops"; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + + if (fake_dequantize_types.count(op_node->Name())) { + auto* op_desc = op_node->Op(); + auto x_var_name = op_desc->Input("X")[0]; + + if (op_desc->HasAttr("max_range")) { + const float max_range = + BOOST_GET_CONST(float, op_desc->GetAttr("max_range")); + std::vector thresholds = {127 * 127 / max_range}; + weight_thresholds->insert(std::make_pair(x_var_name, thresholds)); + } else { + auto scale_name = op_desc->Input("Scales")[0]; + auto* var = scope->FindVar(scale_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound( + "The Scales variable [%s] of dequantize op is not found.", + var)); + + auto* scale_tensor = var->GetMutable(); + auto* scale_data = scale_tensor->data(); + std::vector thresholds{}; + for (int i = 0; i < scale_tensor->numel(); i++) { + thresholds.push_back(scale_data[i]); + } + weight_thresholds->insert(std::make_pair(x_var_name, thresholds)); + } + } + } +} + +void QuantDequantMkldnnPass::CollectInputScalesFromFake( + ir::Graph* graph, Scope* scope, + const std::unordered_set& fake_quantize_types, + std::unordered_map>* var_quant_scales) + const { + VLOG(3) << "gather input scales from fake quantized ops"; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + + if (op_node->Name() == "fake_quantize_dequantize_moving_average_abs_max" || + fake_quantize_types.count(op_node->Name())) { + auto* op_desc = op_node->Op(); + const int bit_length = + BOOST_GET_CONST(int, op_desc->GetAttr("bit_length")); + PADDLE_ENFORCE_EQ(bit_length, 8, platform::errors::InvalidArgument( + "Unsupported number quantization " + "bits: %d, only 8 is supported now.", + bit_length)); + + auto x_var_name = op_desc->Input("X")[0]; + auto scale_name = op_desc->Input("InScale")[0]; + auto out_var_name = op_desc->Output("Out")[0]; + auto* var = scope->FindVar(scale_name); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::NotFound( + "The InScale variable [%s] of quantize op is not found.", var)); + + auto* scale_tensor = var->GetMutable(); + auto* scale_data = scale_tensor->data(); + float scale = 1.0 / scale_data[0]; + if (std::isinf(scale) || std::isnan(scale)) { + scale = 0.0; + } + + if (!var_quant_scales->count(x_var_name)) { + std::vector scale_v = {scale}; + var_quant_scales->insert(std::make_pair(x_var_name, scale_v)); + } + + if (!var_quant_scales->count(out_var_name)) { + std::vector scale_v = {scale}; + var_quant_scales->insert(std::make_pair(out_var_name, scale_v)); + } + } + } +} + +void QuantDequantMkldnnPass::CollectOutputScalesFromAttr( + ir::Graph* graph, + std::unordered_map>* var_quant_scales) + const { + VLOG(3) << "gather output scales from op's attr"; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + + auto* op_desc = op_node->Op(); + if (op_desc->HasAttr("out_threshold")) { + const float attr_scale = + BOOST_GET_CONST(float, op_desc->GetAttr("out_threshold")); + if (attr_scale == 0.0) continue; + float scale = 1.0 / attr_scale; + std::vector scale_v = {scale}; + + auto var_name_map = op_desc->Outputs(); + for (auto iter = var_name_map.begin(); iter != var_name_map.end(); + ++iter) { + for (auto var_name : iter->second) { + var_quant_scales->insert(std::make_pair(var_name, scale_v)); + } + } + } + } +} + +void QuantDequantMkldnnPass::CollectFakeQuantizeOps( + ir::Graph* graph, Node* op_node, + std::unordered_set* nodes2rm) const { + auto* op_desc = op_node->Op(); + auto x_var_name = op_desc->Input("X")[0]; + auto in_scale_name = op_desc->Input("InScale")[0]; + auto out_var_name = op_desc->Output("Out")[0]; + auto out_scale_name = op_desc->Output("OutScale")[0]; + + Node* fake_quant_in = nullptr; + Node* fake_quant_in_scale = nullptr; + for (auto* node_input : op_node->inputs) { + if (node_input->Name() == x_var_name) { + fake_quant_in = node_input; + break; + } else if (node_input->Name() == in_scale_name) { + fake_quant_in_scale = node_input; + break; + } + } + + Node* fake_quant_out = nullptr; + Node* fake_quant_out_scale = nullptr; + for (auto* node_output : op_node->outputs) { + if (node_output->Name() == out_var_name) { + fake_quant_out = node_output; + break; + } else if (node_output->Name() == out_scale_name) { + fake_quant_out_scale = node_output; + break; + } + } + + PADDLE_ENFORCE_NOT_NULL( + fake_quant_in, + platform::errors::NotFound( + "The input var [%s] of quantize op is not found.", x_var_name)); + PADDLE_ENFORCE_NOT_NULL( + fake_quant_out, + platform::errors::NotFound( + "The output var [%s] of quantize op is not found.", out_var_name)); + + std::string input_act_name = fake_quant_in->Var()->Name(); + std::string output_act_name = fake_quant_out->Var()->Name(); + auto outlinks = fake_quant_out->outputs; + for (auto* next_node : outlinks) { + if (!next_node->IsOp()) continue; + next_node->Op()->RenameInput(output_act_name, input_act_name); + IR_NODE_LINK_TO(fake_quant_in, next_node); + } + + nodes2rm->insert(op_node); + nodes2rm->insert(fake_quant_in_scale); + nodes2rm->insert(fake_quant_out); + nodes2rm->insert(fake_quant_out_scale); +} + +void QuantDequantMkldnnPass::CollectFakeDequantizeOps( + ir::Graph* graph, Node* op_node, + std::unordered_set* nodes2rm) const { + auto* op_desc = op_node->Op(); + auto x_var_name = op_desc->Input("X")[0]; + auto out_var_name = op_desc->Output("Out")[0]; + + Node* fake_dequant_in = nullptr; + for (auto* node_input : op_node->inputs) { + if (node_input->Name() == x_var_name) { + fake_dequant_in = node_input; + break; + } + } + + Node* fake_dequant_out = nullptr; + for (auto* node_output : op_node->outputs) { + if (node_output->Name() == out_var_name) { + fake_dequant_out = node_output; + break; + } + } + + PADDLE_ENFORCE_NOT_NULL( + fake_dequant_in, + platform::errors::NotFound( + "The input var [%s] of dequantize op is not found.", x_var_name)); + PADDLE_ENFORCE_NOT_NULL( + fake_dequant_out, + platform::errors::NotFound( + "The output var [%s] of dequantize op is not found.", out_var_name)); + + std::string input_act_name = fake_dequant_in->Var()->Name(); + std::string output_act_name = fake_dequant_out->Var()->Name(); + auto outlinks = fake_dequant_out->outputs; + for (auto* next_node : outlinks) { + next_node->Op()->RenameInput(output_act_name, input_act_name); + IR_NODE_LINK_TO(fake_dequant_in, next_node); + } + + nodes2rm->insert(op_node); + nodes2rm->insert(fake_dequant_out); +} + +void QuantDequantMkldnnPass::RemoveFakeOps( + ir::Graph* graph, + const std::unordered_set& fake_quantize_types, + const std::unordered_set& fake_dequantize_types, + const std::unordered_set& fake_quantize_dequantize_types) + const { + VLOG(3) << "remove fake quantize and dequantize ops"; + + std::unordered_set nodes2rm = {}; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + + if (fake_quantize_types.count(op_node->Name())) { + CollectFakeQuantizeOps(graph, op_node, &nodes2rm); + } else if (fake_dequantize_types.count(op_node->Name())) { + CollectFakeDequantizeOps(graph, op_node, &nodes2rm); + } else if (fake_quantize_dequantize_types.count(op_node->Name())) { + CollectFakeDequantizeOps(graph, op_node, &nodes2rm); + } + } + + GraphSafeRemoveNodes(graph, nodes2rm); +} + +void QuantDequantMkldnnPass::TransposeWeight(Tensor* input) const { + const auto in_dims = input->dims(); + std::vector out_dim_v; + std::vector axis; + for (int i = in_dims.size() - 1; i >= 0; i--) { + axis.push_back(i); + out_dim_v.push_back(in_dims[i]); + } + + const auto out_dims = phi::make_ddim(out_dim_v); + const int rank = axis.size(); + auto in_stride = phi::stride(in_dims); + auto out_stride = phi::stride(out_dims); + const int count = input->numel(); + + Tensor trans_tensor; + trans_tensor.Resize(out_dims); + float* trans_data = trans_tensor.mutable_data(platform::CPUPlace()); + float* in_data = input->mutable_data(platform::CPUPlace()); + + for (int64_t out_idx = 0; out_idx < count; ++out_idx) { + int64_t in_idx = 0; + int64_t tmp_idx = out_idx; + for (int i = 0; i < rank; ++i) { + const int64_t coordinate = tmp_idx / out_stride[i]; + tmp_idx -= coordinate * out_stride[i]; + in_idx += coordinate * in_stride[axis[i]]; + } + trans_data[out_idx] = in_data[in_idx]; + } + + input->Resize(out_dims); + for (int i = 0; i < input->numel(); i++) { + in_data[i] = trans_data[i]; + } +} + +bool QuantDequantMkldnnPass::IsInt8Weight( + Node* op_node, Scope* scope, const std::string& weight_name) const { + auto* op_desc = op_node->Op(); + auto var_name = op_desc->Input(weight_name)[0]; + auto* var = scope->FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound( + "The input persistable [%s] var of [%s] op is not found.", + var_name, op_desc->Type())); + auto* weight_tensor = var->GetMutable(); + auto* weight_data = weight_tensor->data(); + bool is_int8 = true; + for (int i = 0; i < weight_tensor->numel(); i++) { + if (weight_data[i] - static_cast(weight_data[i]) != 0) { + is_int8 = false; + break; + } + } + return is_int8; +} + +void QuantDequantMkldnnPass::DequantizeOpWeights( + Node* op_node, Scope* scope, const std::string& weight_name, + const std::string& output_name, + const std::unordered_map>& + weight_thresholds) const { + auto* op_desc = op_node->Op(); + std::string weight_var_name = op_desc->Input(weight_name)[0]; + std::string output_var_name = op_desc->Output(output_name)[0]; + + std::vector scales; + auto iter = weight_thresholds.find(output_var_name); + if (iter != weight_thresholds.end()) { + scales = iter->second; + } else { + PADDLE_THROW(paddle::platform::errors::Fatal( + "Could not find threshold information for [%s] var, please check if " + "the model is correct.", + output_var_name)); + } + + auto* var = scope->FindVar(weight_var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound( + "The input persistable [%s] var of [%s] op is not found.", + weight_var_name, op_desc->Type())); + auto* weight_tensor = var->GetMutable(); + const auto weight_dims = weight_tensor->dims(); + + const int size = scales.size(); + if (size == 1 || size == weight_dims[0]) { + auto* weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + for (int i = 0; i < weight_tensor->numel(); i++) { + weight_data[i] /= 127; + } + + TransposeWeight(weight_tensor); + + if (size == 1) { + for (int i = 0; i < weight_tensor->numel(); i++) { + weight_data[i] *= scales[0]; + } + } else { + for (int i = 0; i < weight_tensor->numel(); i++) { + weight_data[i] *= scales[i % size]; + } + } + + TransposeWeight(weight_tensor); + } else if (weight_dims.size() > 1 && size == weight_dims[1]) { + auto* weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + for (int i = 0; i < weight_tensor->numel(); i++) { + weight_data[i] /= 127; + } + + int step_n = 1; + for (int i = 1; i < weight_dims.size(); i++) { + step_n *= weight_dims[i]; + } + int step_c = step_n / size; + for (int i = 0; i < weight_dims[0]; i++) { + int begin_n = i * step_n; + for (int j = begin_n; j < begin_n + step_n; j++) { + for (int k = 0; k < size; k++) { + int begin_c = k * step_c; + for (int m = begin_c; m < begin_c + step_c; m++) { + weight_data[m] *= scales[k]; + } + } + } + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The size of weight scales vector (%d) does not " + "match the dimensions (%d) of the weights tensor %s.", + size, weight_tensor->dims().size(), weight_var_name)); + } + + weight_tensor->Resize(weight_dims); +} + +void QuantDequantMkldnnPass::DequantizeWeights( + ir::Graph* graph, Scope* scope, + const std::unordered_map>& + weight_thresholds) const { + VLOG(3) << "dequantize weight for ops which has weight"; + + if (weight_thresholds.empty()) { + VLOG(3) + << "No need to dequantize weights because weight_thresholds is empty."; + return; + } + + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") { + if (IsInt8Weight(op_node, scope, "Filter")) { + DequantizeOpWeights(op_node, scope, "Filter", "Output", + weight_thresholds); + } + } else if (op_node->Name() == "mul" || op_node->Name() == "matmul" || + op_node->Name() == "matmul_v2") { + if (IsInt8Weight(op_node, scope, "Y")) { + DequantizeOpWeights(op_node, scope, "Y", "Out", weight_thresholds); + } + } + } +} + +void QuantDequantMkldnnPass::UpdateActivations(ir::Graph* graph) const { + VLOG(3) << "update conv2d or depthwise_conv2d fused activation"; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (!op_node->IsOp()) continue; + + if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") { + auto* op_desc = op_node->Op(); + if (!op_desc->HasAttr("fuse_activation")) { + std::string activation; + if (op_desc->GetAttrIfExists("fuse_relu")) { + activation = "relu"; + } else if (op_desc->GetAttrIfExists("fuse_brelu")) { + activation = "relu6"; + float alpha = 6.0; + if (op_desc->HasAttr("fuse_brelu_threshold")) { + alpha = BOOST_GET_CONST(float, + op_desc->GetAttr("fuse_brelu_threshold")); + } + op_node->Op()->SetAttr("fuse_alpha", alpha); + } + op_node->Op()->SetAttr("fuse_activation", activation); + } + } + } +} + +void QuantDequantMkldnnPass::RemoveCtrlVars(ir::Graph* graph) const { + VLOG(3) << "remove control flow variable"; + std::unordered_set nodes2rm = {}; + for (auto* op_node : + ir::TopologyVarientSort(*graph, static_cast(0))) { + if (op_node->IsCtrlVar()) { + nodes2rm.insert(op_node); + } + } + + GraphSafeRemoveNodes(graph, nodes2rm); +} + +void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const { + VLOG(3) << "Convert paddle slim quantized model to mkldnn quantized model."; + const std::string pattern_name = "quant_dequant_mkldnn_pass"; + FusePassBase::Init(pattern_name, graph); + + const std::unordered_set skip_ops = { + "conv2d", "depthwise_conv2d", "mul", "matmul", "matmul_v2"}; + + const std::unordered_set fake_quantize_types = { + "fake_quantize_moving_average_abs_max", "fake_quantize_range_abs_max"}; + + const std::unordered_set fake_dequantize_types = { + "fake_dequantize_max_abs", "fake_channel_wise_dequantize_max_abs"}; + + const std::unordered_set fake_quantize_dequantize_types = { + "fake_quantize_dequantize_abs_max", + "fake_quantize_dequantize_moving_average_abs_max", + "fake_channel_wise_quantize_dequantize_abs_max"}; + + std::unordered_map> weight_thresholds{}; + std::unordered_map> var_quant_scales{}; + + auto* scope = param_scope(); + MarkSkipQuantizedOps(graph, skip_ops); + MarkSkipQuantizedPool2d(graph); + CollectInfoFromFake(graph, scope, fake_dequantize_types, &weight_thresholds); + CollectInputScalesFromFake(graph, scope, fake_quantize_types, + &var_quant_scales); + CollectOutputScalesFromAttr(graph, &var_quant_scales); + RemoveFakeOps(graph, fake_quantize_types, fake_dequantize_types, + fake_quantize_dequantize_types); + DequantizeWeights(graph, scope, weight_thresholds); + UpdateActivations(graph); + RemoveCtrlVars(graph); + + // save var_quant_scales in the first op's attr + // for compute_propagate_scales_mkldnn_pass + SaveInfoInTheFirstOp(graph, "has_quant_info", "var_quant_scales", + var_quant_scales); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(quant_dequant_mkldnn_pass, + paddle::framework::ir::QuantDequantMkldnnPass); + +REGISTER_PASS_CAPABILITY(quant_dequant_mkldnn_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("conv2d", 1) + .EQ("fc", 0) + .LE("conv2d_transpose", 2) + .EQ("fake_quantize_abs_max", 0) + .EQ("fake_quantize_range_abs_max", 0) + .EQ("fake_quantize_moving_average_abs_max", 0) + .LE("fake_channel_wise_quantize_abs_max", 1) + .EQ("fake_dequantize_max_abs", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h new file mode 100644 index 0000000000..a9442f7074 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h @@ -0,0 +1,91 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class QuantDequantMkldnnPass : public FusePassBase { + public: + QuantDequantMkldnnPass() = default; + virtual ~QuantDequantMkldnnPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void MarkSkipQuantizedOps( + ir::Graph* graph, const std::unordered_set& skip_ops) const; + + void MarkSkipQuantizedPool2d(ir::Graph* graph) const; + + void CollectInfoFromFake( + ir::Graph* graph, Scope* scope, + const std::unordered_set& fake_dequantize_types, + std::unordered_map>* weight_thresholds) + const; + + void CollectInputScalesFromFake( + ir::Graph* graph, Scope* scope, + const std::unordered_set& fake_quantize_types, + std::unordered_map>* var_quant_scales) + const; + + void CollectOutputScalesFromAttr( + ir::Graph* graph, + std::unordered_map>* var_quant_scales) + const; + + void CollectFakeQuantizeOps(ir::Graph* graph, Node* op_node, + std::unordered_set* nodes2rm) const; + + void CollectFakeDequantizeOps( + ir::Graph* graph, Node* op_node, + std::unordered_set* nodes2rm) const; + + void RemoveFakeOps( + ir::Graph* graph, + const std::unordered_set& fake_quantize_types, + const std::unordered_set& fake_dequantize_types, + const std::unordered_set& fake_quantize_dequantize_types) + const; + + bool IsInt8Weight(Node* op_node, Scope* scope, + const std::string& weight_name) const; + + void TransposeWeight(Tensor* input) const; + + void DequantizeOpWeights( + Node* op_node, Scope* scope, const std::string& weight_name, + const std::string& output_name, + const std::unordered_map>& + weight_thresholds) const; + + void DequantizeWeights( + ir::Graph* graph, Scope* scope, + const std::unordered_map>& + weight_thresholds) const; + + void UpdateActivations(ir::Graph* graph) const; + + void RemoveCtrlVars(ir::Graph* graph) const; +}; +} // namespace ir +} // namespace framework +} // namespace paddle -- GitLab