未验证 提交 011a6a51 编写于 作者: A alncat 提交者: GitHub

added support for fake_quantize_dequantize_abs_max op in quantization… (#30896) (#31162)

* added support for fake_quantize_dequantize_abs_max op in quantization inference pass

* remove const_cast to pass ci

* remove compare operator to pass ci-coverage

* added detailed error message for unregistered tensorrt_subgrah_pass
上级 b0ec6e84
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h" #include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"
#include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
...@@ -75,6 +76,12 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -75,6 +76,12 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
any_op2_desc->Flush(); any_op2_desc->Flush();
auto dequant_type = quant_dequant_op->Op()->Type(); auto dequant_type = quant_dequant_op->Op()->Type();
auto quantized_op_type = any_op2_desc->Type(); auto quantized_op_type = any_op2_desc->Type();
// get weight tensor
auto* weight_tensor =
scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
// Get weight scale // Get weight scale
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") { if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
...@@ -90,26 +97,64 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -90,26 +97,64 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
paddle::platform::is_cpu_place(channel_scale_tensor.place()), paddle::platform::is_cpu_place(channel_scale_tensor.place()),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU.")); "Channel scale tensor's place should be CPU."));
const float* channel_scale_data = channel_scale_tensor.data<float>(); // compute the channel wise abs max of the weight tensor
for (int i = 0; i < channel_scale_tensor.numel(); i++) { int quant_axis =
weight_scale.push_back(range / channel_scale_data[i]); BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
const int64_t channel = w_dims[quant_axis];
weight_scale.resize(channel, 0);
if (quant_axis == 0) {
const int64_t channel_size = weight_tensor->numel() / channel;
for (int64_t i = 0; i < channel; i++) {
auto* start = quantized_weight_data + i * channel_size;
for (int64_t j = 0; j < channel_size; j++) {
weight_scale[i] = std::max(std::abs(start[j]), weight_scale[i]);
}
}
} else if (quant_axis == 1) {
const int64_t step_i = weight_tensor->numel() / w_dims[0];
const int64_t step_j = weight_tensor->numel() / (w_dims[0] * w_dims[1]);
for (int64_t i = 0; i < w_dims[0]; i++) {
for (int64_t j = 0; j < w_dims[1]; j++) {
auto* start = quantized_weight_data + i * step_i + j * step_j;
float abs_max = 0;
for (int64_t k = 0; k < step_j; k++) {
abs_max = std::max(std::abs(start[k]), abs_max);
}
weight_scale[j] = std::max(weight_scale[j], abs_max);
}
}
}
for (int i = 0; i < channel; i++) {
PADDLE_ENFORCE_NE(weight_scale[i], 0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero."));
weight_scale[i] = range / weight_scale[i];
} }
} else { } else {
auto scale_name = quant_dequant_op_outscale->Name(); auto scale_name = quant_dequant_op_outscale->Name();
const LoDTensor& scale_tensor = // compute the abs max of the weight tensor
scope->GetVar(scale_name)->Get<LoDTensor>(); float abs_max_weight = 0.;
const float* scale_data = scale_tensor.data<float>(); for (int j = 0; j < weight_tensor->numel(); j++) {
weight_scale.push_back((range * range) / scale_data[0] / range); abs_max_weight =
std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
}
PADDLE_ENFORCE_NE(abs_max_weight, 0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
weight_scale.push_back((range * range) / abs_max_weight / range);
} }
nodes2rm.insert(quant_dequant_op_outscale); nodes2rm.insert(quant_dequant_op_outscale);
// perform quantize dequantize operations // perform quantize dequantize operations
auto* weight_tensor = // If quantized op is not channel wise, weight scale size = 1;
scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
// If quantized op is fc, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0] // If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1] // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
if (dequant_type == "fake_quantize_dequantize_abs_max") { if (dequant_type == "fake_quantize_dequantize_abs_max") {
...@@ -119,9 +164,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -119,9 +164,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
"%s op weight dequantized by [fake_quantize_dequantize_max_abs] " "%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
"requires weight scale size = 1, but got %d.", "requires weight scale size = 1, but got %d.",
quantized_op_type, weight_scale.size())); quantized_op_type, weight_scale.size()));
PADDLE_ENFORCE_NE(weight_scale[0], 0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized // quantized
quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0]; quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
......
...@@ -206,9 +206,19 @@ class PassRegistry { ...@@ -206,9 +206,19 @@ class PassRegistry {
} }
std::unique_ptr<Pass> Get(const std::string &pass_type) const { std::unique_ptr<Pass> Get(const std::string &pass_type) const {
PADDLE_ENFORCE_EQ(Has(pass_type), true, if (pass_type == "tensorrt_subgraph_pass") {
platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(Has(pass_type), true,
"Pass %s has not been registered.", pass_type)); platform::errors::InvalidArgument(
"Pass %s has not been registered. Please "
"use the paddle inference library "
"compiled with tensorrt or disable "
"the tensorrt engine in inference configuration! ",
pass_type));
} else {
PADDLE_ENFORCE_EQ(Has(pass_type), true,
platform::errors::InvalidArgument(
"Pass %s has not been registered.", pass_type));
}
return map_.at(pass_type)(); return map_.at(pass_type)();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册