未验证 提交 75196cda 编写于 作者: P Pei Yang 提交者: GitHub

Paddle-TRT int8 support mul op channelwise quant (#28422)

* paddle-trt support mul channelwise quant

* add support for depthwise_conv2d

* add errmsg for unsupported op type
上级 c70c1c52
......@@ -195,32 +195,73 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
auto* weight_tensor =
scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
// If quantized op is fc, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
bool valid_scale_size =
(weight_scale.size() == 1 ||
weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
weight_scale.size() == static_cast<size_t>(w_dims[1]));
PADDLE_ENFORCE_EQ(
valid_scale_size, true,
platform::errors::InvalidArgument(
"TRT int8 quant: invalid scale size(%d).", weight_scale.size()));
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
for (int j = 0; j < weight_tensor->numel(); j++) {
if (weight_scale.size() == 1) {
quantized_weight_data[j] *= weight_scale[0];
} else {
if (quantized_op_type == "conv2d_transpose") {
int inner_size = w_dims[2] * w_dims[3];
quantized_weight_data[j] *=
weight_scale[(j / inner_size) % w_dims[1]];
} else {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size];
if (quantized_op_type == "mul" || quantized_op_type == "fc") {
if (dequant_type == "fake_dequantize_max_abs") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), 1,
platform::errors::InvalidArgument(
"mul op weight dequantized by [fake_dequantize_max_abs] "
"requires weight scale size = 1, but got %d.",
weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[0];
}
}
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs] requires weight scale "
"size = 2nd dim of mul's weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
}
}
} else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") {
PADDLE_ENFORCE_EQ(
dequant_type, "fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument("conv2d op must be dequantized by "
"[fake_channel_wise_dequantize_max_"
"abs], but got %s",
dequant_type));
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[0]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size];
}
} else if (quantized_op_type == "conv2d_transpose") {
PADDLE_ENFORCE_EQ(
dequant_type, "fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument(
"conv2d_transpose must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s",
dequant_type));
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"conv2d_transpose op requires weight scale size = channel size "
"of the weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
// create new op_desc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册