Unverified commit e3f8e5cf, authored by Pei Yang, committed by GitHub

trt int8 support conv2d_transpose (#26636)

Parent: 30aab177
@@ -81,7 +81,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
     if (quantized_op_type == "conv2d" ||
         quantized_op_type == "conv2d_fusion" ||
         quantized_op_type == "depthwise_conv2d" ||
-        quantized_op_type == "fc") {
+        quantized_op_type == "fc" ||
+        quantized_op_type == "conv2d_transpose") {
       op_desc->SetAttr("Input_scale", scale_value);
     } else if (quantized_op_type == "mul") {
       op_desc->SetAttr("X_scale", scale_value);
@@ -111,7 +112,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
   std::string input_name = "";
   if (quantized_op_type == "conv2d" ||
       quantized_op_type == "depthwise_conv2d" ||
-      quantized_op_type == "conv2d_fusion") {
+      quantized_op_type == "conv2d_fusion" ||
+      quantized_op_type == "conv2d_transpose") {
     weight_name = "Filter";
     input_name = "Input";
   } else if (quantized_op_type == "mul") {
@@ -122,7 +124,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
     input_name = "Input";
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
-        "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
+        "QuantDequantFuse: We only support conv2d, conv2d_fusion, "
+        "conv2d_transpose, fc, mul for "
         "now."));
   }
   const std::string pattern_name = "dequant_fuse";
@@ -192,10 +195,12 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
       scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
   auto w_dims = weight_tensor->dims();
   // If quantized op is fc, weight scale size = 1;
-  // If quantized op is conv, weight scale size = weight dims[0]
+  // If quantized op is conv2d, weight scale size = weight dims[0]
+  // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
   bool valid_scale_size =
       (weight_scale.size() == 1 ||
-       weight_scale.size() == static_cast<size_t>(w_dims[0]));
+       weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
+       weight_scale.size() == static_cast<size_t>(w_dims[1]));
   PADDLE_ENFORCE_EQ(
       valid_scale_size, true,
       platform::errors::InvalidArgument(
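Why the check accepts both w_dims[0] and w_dims[1]: Paddle lays out a conv2d filter as [out_C, in_C, kH, kW] but a conv2d_transpose filter as [in_C, out_C, kH, kW], so the per-output-channel scale count moves from dims[0] to dims[1]. A minimal standalone sketch of the accepted sizes (illustrative only, not the Paddle source):

#include <cstdint>
#include <vector>

// Size 1 covers per-tensor scales (fc); w_dims[0] covers conv2d
// ([out_C, in_C, kH, kW]); w_dims[1] covers conv2d_transpose
// ([in_C, out_C, kH, kW]).
bool ValidScaleSize(const std::vector<float>& weight_scale,
                    const std::vector<int64_t>& w_dims) {
  return weight_scale.size() == 1 ||
         weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
         weight_scale.size() == static_cast<size_t>(w_dims[1]);
}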
@@ -206,8 +211,14 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
     if (weight_scale.size() == 1) {
       quantized_weight_data[j] *= weight_scale[0];
     } else {
-      int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
-      quantized_weight_data[j] *= weight_scale[j / inner_size];
+      if (quantized_op_type == "conv2d_transpose") {
+        int inner_size = w_dims[2] * w_dims[3];
+        quantized_weight_data[j] *=
+            weight_scale[(j / inner_size) % w_dims[1]];
+      } else {
+        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+        quantized_weight_data[j] *= weight_scale[j / inner_size];
+      }
     }
   }
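The modulo indexing follows from the flat layout of a [in_C, out_C, kH, kW] filter: offset j decomposes as j = ((ic * out_C + oc) * kH + kh) * kW + kw, so the output channel is oc = (j / (kH * kW)) % out_C. A standalone sketch of that loop (hypothetical helper, not the pass itself):

#include <cstdint>
#include <vector>

// Dequantize a conv2d_transpose filter stored as [in_C, out_C, kH, kW]:
// each run of kH * kW consecutive elements shares one output channel,
// and the channel index wraps every out_C runs.
void DequantTransposeFilter(std::vector<float>* w,
                            const std::vector<float>& weight_scale,
                            int64_t out_c, int64_t kh, int64_t kw) {
  const int64_t inner_size = kh * kw;
  for (int64_t j = 0; j < static_cast<int64_t>(w->size()); ++j) {
    (*w)[j] *= weight_scale[(j / inner_size) % out_c];
  }
}

For example, with out_C = 2 and a 3x3 kernel, elements 0-8 scale by weight_scale[0], 9-17 by weight_scale[1], and 18-26 wrap back to weight_scale[0] for the next input channel.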
@@ -220,7 +231,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
   new_op_desc.SetType(quantized_op_type);
   new_op_desc.SetAttr("enable_int8", true);
   if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
-      quantized_op_type == "depthwise_conv2d") {
+      quantized_op_type == "depthwise_conv2d" ||
+      quantized_op_type == "conv2d_transpose") {
     new_op_desc.SetInput("Input", {new_input});
     new_op_desc.SetOutput("Output", {new_output});
   } else if (quantized_op_type == "fc") {
@@ -253,7 +265,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
   std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "depthwise_conv2d", "fc"};
+      "conv2d", "mul", "depthwise_conv2d", "fc", "conv2d_transpose"};
   auto* scope = param_scope();
   for (auto& quant_type : quant_types) {
......
@@ -51,7 +51,13 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
   if (enable_int8) {
 #if IS_TRT_VERSION_GE(5000)
-    CHECK(op_desc.HasAttr("Input_scale"));
+    if (op_desc.Type() != "conv2d_transpose") {
+      PADDLE_ENFORCE_EQ(
+          op_desc.HasAttr("Input_scale"), true,
+          platform::errors::InvalidArgument("Input scale not found. TRT int8"
+                                            " requires conv/deconv to have "
+                                            "input quantization scales."));
+    }
     float in_scale =
         BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
     auto weight_scale =
......
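A note on the multiplication by 127 above: the fuse pass stores Input_scale already divided by the int8 range, so the converter multiplies by 127 to recover the tensor's absolute-max bound, which TensorRT consumes as a symmetric dynamic range. A standalone sketch of just that arithmetic (the 0.5f value is hypothetical):

#include <cstdio>

// Recover the symmetric int8 dynamic range from a normalized scale,
// assuming the attribute holds abs_max / 127 as stored by the pass.
int main() {
  float input_scale_attr = 0.5f;
  float in_scale = input_scale_attr * 127;  // absolute-max of the tensor
  std::printf("dynamic range: [%f, %f]\n", -in_scale, in_scale);
  return 0;
}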
@@ -68,6 +68,7 @@ _out_scale_op_list = [
     "scale",
     "hard_swish",
     "hard_sigmoid",
+    "conv2d_transpose",
 ]
 # list op real input and output names, to avoid processing input such as AxisTensor.
......