diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 611b1bb5eb8b08b7097d64a2fb485f3b9b3e35c0..96f88e70a98d453cbf7d74e6f17194855003555a 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -195,32 +195,73 @@ void FuseDequant(ir::Graph* graph, Scope* scope, auto* weight_tensor = scope->Var(quantized_op_weight_node->Name())->GetMutable(); auto w_dims = weight_tensor->dims(); + float* quantized_weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); // If quantized op is fc, weight scale size = 1; // If quantized op is conv2d, weight scale size = weight dims[0] // If quantized op is conv2d_transpose, weight scale size = weight dims[1] - bool valid_scale_size = - (weight_scale.size() == 1 || - weight_scale.size() == static_cast(w_dims[0]) || - weight_scale.size() == static_cast(w_dims[1])); - PADDLE_ENFORCE_EQ( - valid_scale_size, true, - platform::errors::InvalidArgument( - "TRT int8 quant: invalid scale size(%d).", weight_scale.size())); - float* quantized_weight_data = - weight_tensor->mutable_data(platform::CPUPlace()); - for (int j = 0; j < weight_tensor->numel(); j++) { - if (weight_scale.size() == 1) { - quantized_weight_data[j] *= weight_scale[0]; - } else { - if (quantized_op_type == "conv2d_transpose") { - int inner_size = w_dims[2] * w_dims[3]; - quantized_weight_data[j] *= - weight_scale[(j / inner_size) % w_dims[1]]; - } else { - int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; - quantized_weight_data[j] *= weight_scale[j / inner_size]; + if (quantized_op_type == "mul" || quantized_op_type == "fc") { + if (dequant_type == "fake_dequantize_max_abs") { + PADDLE_ENFORCE_EQ( + weight_scale.size(), 1, + platform::errors::InvalidArgument( + "mul op weight dequantized by [fake_dequantize_max_abs] " + "requires weight scale size = 1, but got %d.", + weight_scale.size())); + for (int j = 0; j < weight_tensor->numel(); j++) { + quantized_weight_data[j] *= weight_scale[0]; } } + if (dequant_type == "fake_channel_wise_dequantize_max_abs") { + PADDLE_ENFORCE_EQ( + weight_scale.size(), static_cast(w_dims[1]), + platform::errors::InvalidArgument( + "mul op weight dequantized by " + "[fake_channel_wise_dequantize_max_abs] requires weight scale " + "size = 2nd dim of mul's weight, which is %d, but got %d.", + static_cast(w_dims[1]), weight_scale.size())); + for (int j = 0; j < weight_tensor->numel(); j++) { + quantized_weight_data[j] *= weight_scale[j % w_dims[1]]; + } + } + } else if (quantized_op_type == "conv2d" || + quantized_op_type == "depthwise_conv2d") { + PADDLE_ENFORCE_EQ( + dequant_type, "fake_channel_wise_dequantize_max_abs", + platform::errors::InvalidArgument("conv2d op must be dequantized by " + "[fake_channel_wise_dequantize_max_" + "abs], but got %s", + dequant_type)); + PADDLE_ENFORCE_EQ( + weight_scale.size(), static_cast(w_dims[0]), + platform::errors::InvalidArgument( + "conv2d op requires weight scale size = channel size of the " + "weight, which is %d, but got %d.", + static_cast(w_dims[0]), weight_scale.size())); + for (int j = 0; j < weight_tensor->numel(); j++) { + int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; + quantized_weight_data[j] *= weight_scale[j / inner_size]; + } + } else if (quantized_op_type == "conv2d_transpose") { + PADDLE_ENFORCE_EQ( + dequant_type, "fake_channel_wise_dequantize_max_abs", + platform::errors::InvalidArgument( + "conv2d_transpose must be dequantized by " + "[fake_channel_wise_dequantize_max_abs], but got %s", + dequant_type)); + PADDLE_ENFORCE_EQ( + weight_scale.size(), static_cast(w_dims[1]), + platform::errors::InvalidArgument( + "conv2d_transpose op requires weight scale size = channel size " + "of the weight, which is %d, but got %d.", + static_cast(w_dims[1]), weight_scale.size())); + for (int j = 0; j < weight_tensor->numel(); j++) { + int inner_size = w_dims[2] * w_dims[3]; + quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]]; + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported quantized op type: %s", quantized_op_type)); } // create new op_desc