未验证 提交 75196cda 编写于 作者: P Pei Yang 提交者: GitHub

Paddle-TRT int8 support mul op channelwise quant (#28422)

* paddle-trt support mul channelwise quant

* add support for depthwise_conv2d

* add errmsg for unsupported op type
上级 c70c1c52
...@@ -195,32 +195,73 @@ void FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -195,32 +195,73 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
auto* weight_tensor = auto* weight_tensor =
scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>(); scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims(); auto w_dims = weight_tensor->dims();
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
// If quantized op is fc, weight scale size = 1; // If quantized op is fc, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0] // If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1] // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
bool valid_scale_size = if (quantized_op_type == "mul" || quantized_op_type == "fc") {
(weight_scale.size() == 1 || if (dequant_type == "fake_dequantize_max_abs") {
weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
weight_scale.size() == static_cast<size_t>(w_dims[1]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
valid_scale_size, true, weight_scale.size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"TRT int8 quant: invalid scale size(%d).", weight_scale.size())); "mul op weight dequantized by [fake_dequantize_max_abs] "
float* quantized_weight_data = "requires weight scale size = 1, but got %d.",
weight_tensor->mutable_data<float>(platform::CPUPlace()); weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
if (weight_scale.size() == 1) {
quantized_weight_data[j] *= weight_scale[0]; quantized_weight_data[j] *= weight_scale[0];
} else { }
if (quantized_op_type == "conv2d_transpose") { }
int inner_size = w_dims[2] * w_dims[3]; if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
quantized_weight_data[j] *= PADDLE_ENFORCE_EQ(
weight_scale[(j / inner_size) % w_dims[1]]; weight_scale.size(), static_cast<size_t>(w_dims[1]),
} else { platform::errors::InvalidArgument(
"mul op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs] requires weight scale "
"size = 2nd dim of mul's weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
}
}
} else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") {
PADDLE_ENFORCE_EQ(
dequant_type, "fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument("conv2d op must be dequantized by "
"[fake_channel_wise_dequantize_max_"
"abs], but got %s",
dequant_type));
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[0]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size]; quantized_weight_data[j] *= weight_scale[j / inner_size];
} }
} else if (quantized_op_type == "conv2d_transpose") {
PADDLE_ENFORCE_EQ(
dequant_type, "fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument(
"conv2d_transpose must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s",
dequant_type));
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"conv2d_transpose op requires weight scale size = channel size "
"of the weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]];
} }
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
} }
// create new op_desc // create new op_desc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册