未验证 提交 e3f8e5cf 编写于 作者: P Pei Yang 提交者: GitHub

trt int8 support conv2d_transpose (#26636)

上级 30aab177
......@@ -81,7 +81,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "fc") {
quantized_op_type == "fc" ||
quantized_op_type == "conv2d_transpose") {
op_desc->SetAttr("Input_scale", scale_value);
} else if (quantized_op_type == "mul") {
op_desc->SetAttr("X_scale", scale_value);
......@@ -111,7 +112,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
std::string input_name = "";
if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "conv2d_fusion") {
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "conv2d_transpose") {
weight_name = "Filter";
input_name = "Input";
} else if (quantized_op_type == "mul") {
......@@ -122,7 +124,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
input_name = "Input";
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
"QuantDequantFuse: We only support conv2d, conv2d_fusion, "
"conv2d_transpose, fc, mul for "
"now."));
}
const std::string pattern_name = "dequant_fuse";
......@@ -192,10 +195,12 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();
// If quantized op is fc, weight scale size = 1;
// If quantized op is conv, weight scale size = weight dims[0]
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
bool valid_scale_size =
(weight_scale.size() == 1 ||
weight_scale.size() == static_cast<size_t>(w_dims[0]));
weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
weight_scale.size() == static_cast<size_t>(w_dims[1]));
PADDLE_ENFORCE_EQ(
valid_scale_size, true,
platform::errors::InvalidArgument(
......@@ -206,8 +211,14 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
if (weight_scale.size() == 1) {
quantized_weight_data[j] *= weight_scale[0];
} else {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size];
if (quantized_op_type == "conv2d_transpose") {
int inner_size = w_dims[2] * w_dims[3];
quantized_weight_data[j] *=
weight_scale[(j / inner_size) % w_dims[1]];
} else {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size];
}
}
}
......@@ -220,7 +231,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
new_op_desc.SetType(quantized_op_type);
new_op_desc.SetAttr("enable_int8", true);
if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d") {
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "conv2d_transpose") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Output", {new_output});
} else if (quantized_op_type == "fc") {
......@@ -253,7 +265,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d", "fc"};
"conv2d", "mul", "depthwise_conv2d", "fc", "conv2d_transpose"};
auto* scope = param_scope();
for (auto& quant_type : quant_types) {
......
......@@ -51,7 +51,13 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("Input_scale"));
if (op_desc.Type() != "conv2d_transpose") {
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("Input_scale"), true,
platform::errors::InvalidArgument("Input scale not found. TRT int8"
" requires conv/deconv to have "
"input quantization scales."));
}
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
......
......@@ -68,6 +68,7 @@ _out_scale_op_list = [
"scale",
"hard_swish",
"hard_sigmoid",
"conv2d_transpose",
]
# list op real input and output names, to avoid processing input such as AxisTensor.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册