From 011a6a514d483237edcb5f4ba840972714c6da26 Mon Sep 17 00:00:00 2001 From: alncat Date: Wed, 24 Feb 2021 18:35:20 +0800 Subject: [PATCH] =?UTF-8?q?added=20support=20for=20fake=5Fquantize=5Fdequa?= =?UTF-8?q?ntize=5Fabs=5Fmax=20op=20in=20quantization=E2=80=A6=20(#30896)?= =?UTF-8?q?=20(#31162)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added support for fake_quantize_dequantize_abs_max op in quantization inference pass * remove const_cast to pass ci * remove compare operator to pass ci-coverage * added detailed error message for unregistered tensorrt_subgrah_pass --- .../ir/delete_quant_dequant_filter_op_pass.cc | 74 +++++++++++++++---- paddle/fluid/framework/ir/pass.h | 16 +++- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc index 8b3606b588a..4379bba6380 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h" +#include #include #include #include @@ -75,6 +76,12 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { any_op2_desc->Flush(); auto dequant_type = quant_dequant_op->Op()->Type(); auto quantized_op_type = any_op2_desc->Type(); + // get weight tensor + auto* weight_tensor = + scope->GetVar(quant_dequant_op_x->Name())->GetMutable(); + auto w_dims = weight_tensor->dims(); + float* quantized_weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); // Get weight scale if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") { @@ -90,26 +97,64 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { paddle::platform::is_cpu_place(channel_scale_tensor.place()), platform::errors::InvalidArgument( "Channel scale tensor's place should be CPU.")); - const float* channel_scale_data = channel_scale_tensor.data(); - for (int i = 0; i < channel_scale_tensor.numel(); i++) { - weight_scale.push_back(range / channel_scale_data[i]); + // compute the channel wise abs max of the weight tensor + int quant_axis = + BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis")); + + PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument( + "'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + + const int64_t channel = w_dims[quant_axis]; + weight_scale.resize(channel, 0); + if (quant_axis == 0) { + const int64_t channel_size = weight_tensor->numel() / channel; + for (int64_t i = 0; i < channel; i++) { + auto* start = quantized_weight_data + i * channel_size; + for (int64_t j = 0; j < channel_size; j++) { + weight_scale[i] = std::max(std::abs(start[j]), weight_scale[i]); + } + } + } else if (quant_axis == 1) { + const int64_t step_i = weight_tensor->numel() / w_dims[0]; + const int64_t step_j = weight_tensor->numel() / (w_dims[0] * w_dims[1]); + for (int64_t i = 0; i < w_dims[0]; i++) { + for (int64_t j = 0; j < w_dims[1]; j++) { + auto* start = quantized_weight_data + i * step_i + j * step_j; + float abs_max = 0; + for (int64_t k = 0; k < step_j; k++) { + abs_max = std::max(std::abs(start[k]), abs_max); + } + weight_scale[j] = std::max(weight_scale[j], abs_max); + } + } + } + for (int i = 0; i < channel; i++) { + PADDLE_ENFORCE_NE(weight_scale[i], 0, + platform::errors::InvalidArgument( + "Weight scale should be nonzero, but get zero.")); + weight_scale[i] = range / weight_scale[i]; } } else { auto scale_name = quant_dequant_op_outscale->Name(); - const LoDTensor& scale_tensor = - scope->GetVar(scale_name)->Get(); - const float* scale_data = scale_tensor.data(); - weight_scale.push_back((range * range) / scale_data[0] / range); + // compute the abs max of the weight tensor + float abs_max_weight = 0.; + for (int j = 0; j < weight_tensor->numel(); j++) { + abs_max_weight = + std::max(abs_max_weight, std::abs(quantized_weight_data[j])); + } + PADDLE_ENFORCE_NE(abs_max_weight, 0, + platform::errors::InvalidArgument( + "Weight scale should be nonzero, but get zero")); + weight_scale.push_back((range * range) / abs_max_weight / range); } nodes2rm.insert(quant_dequant_op_outscale); + // perform quantize dequantize operations - auto* weight_tensor = - scope->GetVar(quant_dequant_op_x->Name())->GetMutable(); - auto w_dims = weight_tensor->dims(); - float* quantized_weight_data = - weight_tensor->mutable_data(platform::CPUPlace()); - // If quantized op is fc, weight scale size = 1; + // If quantized op is not channel wise, weight scale size = 1; // If quantized op is conv2d, weight scale size = weight dims[0] // If quantized op is conv2d_transpose, weight scale size = weight dims[1] if (dequant_type == "fake_quantize_dequantize_abs_max") { @@ -119,9 +164,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { "%s op weight dequantized by [fake_quantize_dequantize_max_abs] " "requires weight scale size = 1, but got %d.", quantized_op_type, weight_scale.size())); - PADDLE_ENFORCE_NE(weight_scale[0], 0, - platform::errors::InvalidArgument( - "Weight scale should be nonzero, but get zero")); for (int j = 0; j < weight_tensor->numel(); j++) { // quantized quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0]; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index a3b1b33d268..9c306479bf5 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -206,9 +206,19 @@ class PassRegistry { } std::unique_ptr Get(const std::string &pass_type) const { - PADDLE_ENFORCE_EQ(Has(pass_type), true, - platform::errors::InvalidArgument( - "Pass %s has not been registered.", pass_type)); + if (pass_type == "tensorrt_subgraph_pass") { + PADDLE_ENFORCE_EQ(Has(pass_type), true, + platform::errors::InvalidArgument( + "Pass %s has not been registered. Please " + "use the paddle inference library " + "compiled with tensorrt or disable " + "the tensorrt engine in inference configuration! ", + pass_type)); + } else { + PADDLE_ENFORCE_EQ(Has(pass_type), true, + platform::errors::InvalidArgument( + "Pass %s has not been registered.", pass_type)); + } return map_.at(pass_type)(); } -- GitLab