From 38faed7fca3f29bf9f94e4ebc57a21977095cd4b Mon Sep 17 00:00:00 2001 From: alncat Date: Thu, 14 Jan 2021 13:50:16 +0800 Subject: [PATCH] Added support for inference using quantization aware trained dygraph (#30288) (#30402) --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/conv_elementwise_add_fuse_pass.cc | 8 + .../ir/delete_quant_dequant_filter_op_pass.cc | 237 ++++++++++++++++++ .../ir/delete_quant_dequant_filter_op_pass.h | 37 +++ .../ir/delete_quant_dequant_op_pass.cc | 4 +- paddle/fluid/framework/ir/fc_fuse_pass.cc | 12 + .../framework/ir/graph_pattern_detector.cc | 58 +++++ .../framework/ir/graph_pattern_detector.h | 30 +++ .../framework/ir/map_matmul_to_mul_pass.cc | 104 +++++++- .../framework/ir/map_matmul_to_mul_pass.h | 8 + paddle/fluid/framework/scope.cc | 7 + paddle/fluid/framework/scope.h | 4 + paddle/fluid/inference/api/analysis_config.cc | 2 +- .../inference/api/paddle_pass_builder.cc | 6 +- .../inference/tensorrt/convert/conv2d_op.cc | 13 +- .../fluid/inference/tensorrt/convert/fc_op.cc | 11 +- paddle/fluid/inference/tensorrt/op_teller.cc | 2 + 17 files changed, 534 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc create mode 100644 paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 397ab88fb79..4b85a7d129b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -85,6 +85,7 @@ pass_library(runtime_context_cache_pass base) pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) +pass_library(delete_quant_dequant_filter_op_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(skip_layernorm_fuse_pass base) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc index 9121047d2fa..bbe66baee2f 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc @@ -62,6 +62,14 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { new_op_desc.SetOutput("Output", {output_name}); new_op_desc.SetAttr("is_test", true); new_op_desc.SetAttr("use_cudnn", false); + auto* elementwise_add_op_desc = elementwise_add_op->Op(); + auto out_threshold_attr = + elementwise_add_op_desc->GetNullableAttr("out_threshold"); + // set the out_threshold of the elementwise add op to be the out_threshold + // of the conv2d_fusion + if (out_threshold_attr.which()) { + new_op_desc.SetAttr("out_threshold", out_threshold_attr); + } new_op_desc.Flush(); // Create a new node for the fused op. diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc new file mode 100644 index 00000000000..8b3606b588a --- /dev/null +++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc @@ -0,0 +1,237 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
+#define GET_NODES                         \
+  GET_IR_NODE(quant_dequant_op_x);        \
+  GET_IR_NODE(quant_dequant_op);          \
+  GET_IR_NODE(quant_dequant_op_out);      \
+  GET_IR_NODE(quant_dequant_op_outscale); \
+  GET_IR_NODE(any_op2);
+
+// Delete quant_dequant_op, then quantize and dequantize weight
+void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
+  const std::string pattern_name = "delete_quantdequant_filter_op_pattern";
+  FusePassBase::Init(pattern_name, graph);
+
+  GraphPatternDetector gpd;
+
+  // Create pattern
+  patterns::DeleteQuantDequantFilterOpPattern pattern(gpd.mutable_pattern(),
+                                                      pattern_name);
+  pattern();
+  auto* scope = param_scope();
+  int found_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_NODES;
+
+    std::unordered_set<const Node*> nodes2rm = {};
+    int bit_length =
+        BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
+    int range = ((1 << (bit_length - 1)) - 1);
+    std::vector<float> weight_scale;
+    std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
+
+    auto* any_op2_desc = any_op2->Op();
+    auto var_map = any_op2_desc->Inputs();
+    std::string arg_name = "";
+    for (auto& name_m : var_map) {
+      if (std::find(name_m.second.begin(), name_m.second.end(),
+                    quant_dequant_op_out_name) != name_m.second.end()) {
+        arg_name = name_m.first;
+        break;
+      }
+    }
+    PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
+                                              "Cannot find the input %s.",
+                                              quant_dequant_op_out_name));
+    any_op2_desc->SetAttr("enable_int8", true);
+    any_op2_desc->SetAttr("bit_length", bit_length);
+    // modify the any_op2's inputs
+    any_op2_desc->Flush();
+    auto dequant_type = quant_dequant_op->Op()->Type();
+    auto quantized_op_type = any_op2_desc->Type();
+
+    // Get weight scale
+    if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
+      auto scales_name = quant_dequant_op->Op()->Output("OutScale");
+      PADDLE_ENFORCE_EQ(scales_name.size(), 1,
+                        platform::errors::InvalidArgument(
+                            "Scales size in channel-wise quant dequantize op "
+                            "should be 1, got %d.",
+                            scales_name.size()));
+      const LoDTensor& channel_scale_tensor =
+          scope->GetVar(scales_name[0])->Get<LoDTensor>();
+      PADDLE_ENFORCE(
+          paddle::platform::is_cpu_place(channel_scale_tensor.place()),
+          platform::errors::InvalidArgument(
+              "Channel scale tensor's place should be CPU."));
+      const float* channel_scale_data = channel_scale_tensor.data<float>();
+      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
+        weight_scale.push_back(range / channel_scale_data[i]);
+      }
+    } else {
+      auto scale_name = quant_dequant_op_outscale->Name();
+      const LoDTensor& scale_tensor =
+          scope->GetVar(scale_name)->Get<LoDTensor>();
+      const float* scale_data = scale_tensor.data<float>();
+      weight_scale.push_back((range * range) / scale_data[0] / range);
+    }
+
+    nodes2rm.insert(quant_dequant_op_outscale);
+    // perform quantize and dequantize operations
+    auto* weight_tensor =
+        scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
+    auto w_dims = weight_tensor->dims();
+    float* quantized_weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
+    // If quantized op is fc, weight scale size = 1;
+    // If quantized op is conv2d, weight scale size = weight dims[0]
+    // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
+    if (dequant_type == "fake_quantize_dequantize_abs_max") {
+      PADDLE_ENFORCE_EQ(
+          weight_scale.size(), 1,
+          platform::errors::InvalidArgument(
+              "%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
+              "requires weight scale size = 1, but got %d.",
+              quantized_op_type, weight_scale.size()));
+      PADDLE_ENFORCE_NE(weight_scale[0], 0,
+                        platform::errors::InvalidArgument(
+                            "Weight scale should be nonzero, but got zero"));
+      for (int j = 0; j < weight_tensor->numel(); j++) {
+        // quantized
+        quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
+        quantized_weight_data[j] = std::round(quantized_weight_data[j]);
+        // dequantized
+        quantized_weight_data[j] /= weight_scale[0];
+      }
+    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+               quantized_op_type == "fc") {
+      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
+        PADDLE_ENFORCE_EQ(
+            weight_scale.size(), static_cast<size_t>(w_dims[1]),
+            platform::errors::InvalidArgument(
+                "mul op weight dequantized by "
+                "[fake_channel_wise_quantize_dequantize_abs_max] requires "
+                "weight scale "
+                "size = 2nd dim of mul's weight, which is %zu, but got %zu.",
+                static_cast<size_t>(w_dims[1]), weight_scale.size()));
+        for (int j = 0; j < weight_tensor->numel(); j++) {
+          // quantized
+          PADDLE_ENFORCE_NE(
+              weight_scale[j % w_dims[1]], 0,
+              platform::errors::InvalidArgument(
+                  "fc op weight scale should be nonzero, but got zero"));
+          quantized_weight_data[j] =
+              quantized_weight_data[j] * weight_scale[j % w_dims[1]];
+          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
+          // dequantized
+          quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
+        }
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unsupported quantized op type: %s", quantized_op_type));
+      }
+    } else if (quantized_op_type == "conv2d" ||
+               quantized_op_type == "depthwise_conv2d") {
+      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
+        PADDLE_ENFORCE_EQ(
+            weight_scale.size(), static_cast<size_t>(w_dims[0]),
+            platform::errors::InvalidArgument(
+                "conv2d op requires weight scale size = channel size of the "
+                "weight, which is %zu, but got %zu.",
+                static_cast<size_t>(w_dims[0]), weight_scale.size()));
+        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+        for (int j = 0; j < weight_tensor->numel(); j++) {
+          // quantized
+          PADDLE_ENFORCE_NE(
+              weight_scale[j / inner_size], 0,
+              platform::errors::InvalidArgument(
+                  "conv2d op weight scale should be nonzero, but got zero"));
+          quantized_weight_data[j] =
+              quantized_weight_data[j] * weight_scale[j / inner_size];
+          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
+          // dequantized
+          quantized_weight_data[j] /= weight_scale[j / inner_size];
+        }
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unsupported quantized op type: %s", quantized_op_type));
+      }
+    } else if (quantized_op_type == "conv2d_transpose") {
+      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
+        PADDLE_ENFORCE_EQ(
+            weight_scale.size(), static_cast<size_t>(w_dims[0]),
+            platform::errors::InvalidArgument(
+                "conv2d_transpose op requires weight scale size = channel size "
+                "of the "
+                "weight, which is %zu, but got %zu.",
+                static_cast<size_t>(w_dims[1]), weight_scale.size()));
+        int inner_size = w_dims[2] * w_dims[3];
+        for (int j = 0; j < weight_tensor->numel(); j++) {
+          // quantized
+          PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
+                            platform::errors::InvalidArgument(
+                                "conv2d_transpose op weight scale should be "
+                                "nonzero, but got zero"));
+          quantized_weight_data[j] = quantized_weight_data[j] *
+                                     weight_scale[(j / inner_size) % w_dims[1]];
+          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
+          // dequantized
+          quantized_weight_data[j] /=
+              weight_scale[(j / inner_size) % w_dims[1]];
+        }
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unsupported quantized op type: %s", quantized_op_type));
+      }
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unsupported quantized op type: %s", quantized_op_type));
+    }
+    nodes2rm.insert(quant_dequant_op_out);
+
+    // link weight in quant_dequant_op_x to any_op2
+    any_op2_desc->RenameInput(quant_dequant_op_out->Var()->Name(),
+                              quant_dequant_op_x->Var()->Name());
+    any_op2_desc->SetAttr("weight_scale", weight_scale);
+    any_op2_desc->Flush();
+    IR_NODE_LINK_TO(quant_dequant_op_x, any_op2);
+    nodes2rm.insert(quant_dequant_op);
+    GraphSafeRemoveNodes(graph, nodes2rm);
+    found_count++;
+  };
+  gpd(graph, handler);
+  AddStatis(found_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_quant_dequant_filter_op_pass,
+              paddle::framework::ir::DeleteQuantDequantFilterOpPass);
diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
new file mode 100644
index 00000000000..0409032d938
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+
+class DeleteQuantDequantFilterOpPass : public FusePassBase {
+ public:
+  virtual ~DeleteQuantDequantFilterOpPass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
index 886b080c662..232b7c4c074 100644
--- a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
@@ -49,10 +49,10 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
     std::string input_scale_var_name =
         quant_dequant_op->Op()->Input("InScale").front();
     const LoDTensor& input_scale_tensor =
-        scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
+        scope->GetVar(input_scale_var_name)->Get<LoDTensor>();
 
     const float* input_scale_data = input_scale_tensor.data<float>();
-    float input_scale = input_scale_data[0];
+    float input_scale = input_scale_data[0] / 127.;
     auto* any_op2_desc = any_op2->Op();
     // auto input_args_names = any_op2_desc->InputArgumentNames();
     auto var_map = any_op2_desc->Inputs();
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index 103fa0f5faf..2f646553614 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -149,6 +149,18 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
     desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
   }
 
+  auto* elementwise_add_op_desc = elementwise_add->Op();
+  // if we can find out_threshold in elementwise_add, then set it as the
+  // out_threshold of fc
+  auto out_threshold_attr =
+      elementwise_add_op_desc->GetNullableAttr("out_threshold");
+  if (out_threshold_attr.which()) {
+    VLOG(4) << "setting out_threshold: "
+            << BOOST_GET_CONST(float, out_threshold_attr);
+    desc.SetAttr("out_threshold", out_threshold_attr);
+  }
+  desc.Flush();
+
   auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
if (with_relu) { GraphSafeRemoveNodes( diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 22f6388597c..4efdfb01e38 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1634,6 +1634,27 @@ PDNode *patterns::MatmulWithInputOps::operator()() { return matmul_out; } +PDNode *patterns::Flatten2Matmul::operator()() { + auto flatten2_in_x = pattern->NewNode(flatten2_in_x_repr()) + ->assert_is_op_input("flatten2", "X") + ->AsInput(); + auto flatten2_op = + pattern->NewNode(flatten2_op_repr())->assert_is_op("flatten2"); + auto matmul_in_x = pattern->NewNode(matmul_in_x_repr()) + ->assert_is_op_output("flatten2", "Out") + ->assert_is_op_input("matmul", "X"); + auto matmul_in_y = + pattern->NewNode(matmul_in_y_repr())->assert_is_op_input("matmul", "Y"); + auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + auto matmul_out = pattern->NewNode(matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("matmul", "Out"); + + flatten2_op->LinksFrom({flatten2_in_x}).LinksTo({matmul_in_x}); + matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out}); + return matmul_out; +} + PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); @@ -2495,6 +2516,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() { any_op2->LinksFrom({quant_dequant_out}); } +void patterns::DeleteQuantDequantFilterOpPattern::operator()() { + auto quant_dequant_op_x = + pattern->NewNode(quant_dequant_op_x_repr()) + ->assert_is_ops_input( + {"fake_channel_wise_quantize_dequantize_abs_max", + "fake_quantize_dequantize_abs_max"}, + "X") + ->AsInput(); + + auto quant_dequant_op = + pattern->NewNode(quant_dequant_op_repr()) + ->assert_is_ops({"fake_channel_wise_quantize_dequantize_abs_max", + "fake_quantize_dequantize_abs_max"}); + + auto quant_dequant_out = + pattern->NewNode(quant_dequant_op_out_repr()) + ->assert_is_ops_output( + {"fake_channel_wise_quantize_dequantize_abs_max", + "fake_quantize_dequantize_abs_max"}, + "Out") + ->AsIntermediate(); + + auto quant_dequant_op_outscale = + pattern->NewNode(quant_dequant_op_outscale_repr()) + ->assert_is_ops_output( + {"fake_channel_wise_quantize_dequantize_abs_max", + "fake_quantize_dequantize_abs_max"}, + "OutScale") + ->AsOutput(); + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + quant_dequant_op->LinksFrom({quant_dequant_op_x}); + quant_dequant_op_outscale->LinksFrom({quant_dequant_op}); + quant_dequant_out->LinksFrom({quant_dequant_op}); + any_op2->LinksFrom({quant_dequant_out}); +} + PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( bool with_reshape_xshape, bool with_transpose_xshape) { auto reshape_op = diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 83feaa3a4bf..8167400221c 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -996,6 +996,21 @@ struct MatmulWithInputOps : public PatternBase { PATTERN_DECL_NODE(matmul_out); }; +// Flatten2 + Matmul +// Forward pass. 
+struct Flatten2Matmul : public PatternBase { + Flatten2Matmul(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "flatten2_matmul") {} + + PDNode* operator()(); + PATTERN_DECL_NODE(flatten2_in_x); + PATTERN_DECL_NODE(flatten2_op); + PATTERN_DECL_NODE(matmul_in_x); + PATTERN_DECL_NODE(matmul_in_y); + PATTERN_DECL_NODE(matmul_op); + PATTERN_DECL_NODE(matmul_out); +}; + // Concat op // Forward pass for concat. // concat_out is a result of the operator. @@ -1426,6 +1441,21 @@ struct DeleteQuantDequantOpPattern : public PatternBase { PATTERN_DECL_NODE(any_op2); }; +struct DeleteQuantDequantFilterOpPattern : public PatternBase { + DeleteQuantDequantFilterOpPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, + "delete_quantdequant_filter_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(quant_dequant_op_x); + PATTERN_DECL_NODE(quant_dequant_op); + PATTERN_DECL_NODE(quant_dequant_op_outscale); + PATTERN_DECL_NODE(quant_dequant_op_out); + PATTERN_DECL_NODE(any_op2); +}; + // Reshape + Transpose + Matmul // named nodes: // reshape_op, reshape_out, reshape_xshape, diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc index 8c4e6f33058..d86fb5c9ccc 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -71,7 +71,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { desc.SetOutput("Out", {matmul_out->Name()}); desc.SetAttr("x_num_col_dims", 1); desc.SetAttr("y_num_col_dims", 1); - + if (matmul_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + } auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); @@ -137,7 +141,11 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetOutput("Out", {matmul_out->Name()}); desc.SetAttr("x_num_col_dims", 1); desc.SetAttr("y_num_col_dims", 1); - + if (matmul_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + } auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(squeeze2_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); @@ -205,7 +213,11 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetOutput("Out", {matmul_out->Name()}); desc.SetAttr("x_num_col_dims", 1); desc.SetAttr("y_num_col_dims", 1); - + if (matmul_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + } auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(reshape2_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); @@ -219,6 +231,83 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } +void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "flatten2_matmul_fuse_pass"; + 
FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Flatten2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope); + fuse_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "fuse flatten2+matmul to mul"; + GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern); + bool pattern_found = true; + + size_t flatten2_in_nums = flatten2_op->inputs.size(); + auto flatten2_in_x_shape = flatten2_in_x->Var()->GetShape(); + size_t flatten2_in_x_rank = flatten2_in_x_shape.size(); + int flatten2_axis = + BOOST_GET_CONST(int, flatten2_op->Op()->GetAttr("axis")); + // only convert matmul to mul when the flatten2 has a single input + // and the rank of input is 4 and the size of the output of matmul + // is 1. + pattern_found = pattern_found && flatten2_in_nums == 1 && + flatten2_in_x_rank == 4 && + (matmul_in_x->outputs).size() == 1; + + bool transpose_X = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); + bool transpose_Y = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y")); + float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha")); + size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size(); + size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size(); + pattern_found = pattern_found && !transpose_X && !transpose_Y && + std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 && + matmul_in_y_rank == 2; + + std::vector& next_ops = matmul_out->outputs; + // we further require the matmul op is followed by one elementwise + // add op. 
+ pattern_found = pattern_found && next_ops.size() == 1 && + next_ops[0]->Name() == "elementwise_add"; + + if (pattern_found) { + OpDesc desc; + desc.SetType("mul"); + desc.SetInput("X", {flatten2_in_x->Name()}); + desc.SetInput("Y", {matmul_in_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", flatten2_axis); + desc.SetAttr("y_num_col_dims", 1); + if (matmul_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + } + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(flatten2_in_x, mul_node); + IR_NODE_LINK_TO(matmul_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op}); + ++found_count; + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + } // namespace ir } // namespace framework } // namespace paddle @@ -247,3 +336,12 @@ REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass) .LE("matmul", 1) .EQ("reshape2", 0) .EQ("mul", 0)); + +REGISTER_PASS(flatten2_matmul_fuse_pass, + paddle::framework::ir::Flatten2MatmulFusePass); +REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("matmul", 1) + .EQ("flatten2", 0) + .EQ("mul", 0)); diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h index 1c89c97f96e..85067a6f642 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h @@ -101,6 +101,14 @@ class Reshape2MatmulFusePass : public FusePassBase { void ApplyImpl(Graph* graph) const override; }; +class Flatten2MatmulFusePass : public FusePassBase { + public: + virtual ~Flatten2MatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 45e4c3edb05..5a83fed2d0f 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -83,6 +83,13 @@ Variable* Scope::FindVar(const std::string& name) const { return FindVarInternal(name); } +Variable* Scope::GetVar(const std::string& name) const { + auto* var = FindVar(name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Cannot find %s in scope.", name)); + return var; +} + Variable* Scope::FindLocalVar(const std::string& name) const { SCOPE_VARS_READER_LOCK return FindVarLocally(name); diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 922e9a9b272..bab57e529df 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -81,6 +81,10 @@ class Scope { /// Caller doesn't own the returned Variable. Variable* FindVar(const std::string& name) const; + // Get a variable in the scope or any of its ancestors. Enforce + /// the returned Variable is not nullptr + Variable* GetVar(const std::string& name) const; + /// Find a variable in the current scope. /// Return nullptr if cannot find. /// Caller doesn't own the returned Variable. 
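Note on the Scope change above: Scope::GetVar is a checked variant of Scope::FindVar. FindVar returns nullptr for a missing variable, while GetVar enforces existence and raises NotFound, which is why the passes and TensorRT converters in this patch switch to it where a missing scale or bias tensor should be a hard error rather than a silent nullptr dereference. A minimal usage sketch follows; the variable name "w_scale" and the ReadScale helper are illustrative only and not part of this patch:

    #include "paddle/fluid/framework/lod_tensor.h"
    #include "paddle/fluid/framework/scope.h"

    float ReadScale(const paddle::framework::Scope& scope) {
      // FindVar: the caller must handle the nullptr case explicitly.
      auto* maybe_var = scope.FindVar("w_scale");
      if (maybe_var == nullptr) return 0.f;
      // GetVar: throws NotFound if the name is absent, so the result
      // can be dereferenced directly.
      const auto& t =
          scope.GetVar("w_scale")->Get<paddle::framework::LoDTensor>();
      return t.data<float>()[0];
    }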
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index fcef2a5cbc9..7c6ce00d5d6 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -345,7 +345,7 @@ void AnalysisConfig::Update() { pass_builder()->ClearPasses(); for (const auto &pass : kTRTSubgraphPasses) { if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 && - (pass == "conv_bn_fuse_pass" || pass == "fc_fuse_pass")) { + (pass == "conv_bn_fuse_pass")) { continue; } pass_builder()->AppendPass(pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 4b0bfbe2af1..2e9b8e0d145 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -77,6 +77,7 @@ const std::vector kTRTSubgraphPasses({ "shuffle_channel_detect_pass", // "quant_conv2d_dequant_fuse_pass", // "delete_quant_dequant_op_pass", // + "delete_quant_dequant_filter_op_pass", // // "fc_fuse_pass", // "simplify_with_basic_ops_pass", // "embedding_eltwise_layernorm_fuse_pass", // @@ -86,15 +87,16 @@ const std::vector kTRTSubgraphPasses({ "conv_bn_fuse_pass", // "squeeze2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", // + "flatten2_matmul_fuse_pass", // "map_matmul_to_mul_pass", // "fc_fuse_pass", // + "conv_elementwise_add_fuse_pass", // "tensorrt_subgraph_pass", // "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 "conv_elementwise_add_act_fuse_pass", // "conv_elementwise_add2_act_fuse_pass", // - "conv_elementwise_add_fuse_pass", // #endif // "transpose_flatten_concat_fuse_pass", }); @@ -118,6 +120,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "multihead_matmul_fuse_pass_v2", // "squeeze2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", // + "flatten2_matmul_fuse_pass", // "map_matmul_to_mul_pass", // "fc_fuse_pass", // "fc_elementwise_layernorm_fuse_pass", // @@ -172,6 +175,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "seq_concat_fc_fuse_pass", // "squeeze2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", // + "flatten2_matmul_fuse_pass", // "map_matmul_to_mul_pass", // "fc_fuse_pass", // "repeated_fc_relu_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index f582d7e0705..17652afe771 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -105,8 +105,18 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), static_cast(Y_t->numel())}; + float* bias_data = nullptr; + size_t bias_size = 0; + if (op_desc.Type() == "conv2d_fusion") { + auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front()); + auto* bias_tensor_data = bias_tensor->GetMutable(); + bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(), + bias_tensor_data, false); + bias_size = static_cast(bias_tensor_data->numel()); + } - TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; + TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), bias_size}; auto* layer = fadd_layer(const_cast(X), n_output, n_input, nv_ksize, weight, bias); PADDLE_ENFORCE_NOT_NULL(layer, @@ -184,4 +194,5 @@ class Deconv2dOpConverter : public 
OpConverter { } // namespace paddle REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); +REGISTER_TRT_OP_CONVERTER(conv2d_fusion, Conv2dOpConverter); REGISTER_TRT_OP_CONVERTER(conv2d_transpose, Deconv2dOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index cd16ed73965..9ef027b1c2e 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -67,10 +67,11 @@ class FcOpConverter : public OpConverter { // assigned from CPU memory, which can't be avoided. float* weight_data = nullptr; bool enable_int8 = op_desc.HasAttr("enable_int8"); + float in_scale = 0.; if (enable_int8) { #if IS_TRT_VERSION_GE(5000) CHECK(op_desc.HasAttr(i_name + "_scale")); - float in_scale = + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127; auto weight_scale = BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); @@ -131,7 +132,7 @@ class FcOpConverter : public OpConverter { float* bias_data = nullptr; int bias_num = 0; if (with_bias) { - auto* b_v = scope.FindVar(op_desc.Input("Bias").front()); + auto* b_v = scope.GetVar(op_desc.Input("Bias").front()); auto* b_t = b_v->GetMutable(); bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t, false); @@ -183,6 +184,9 @@ class FcOpConverter : public OpConverter { auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X); reshape_layer->setReshapeDimensions(reshape_dim); reshape_itensor = reshape_layer->getOutput(0); + if (enable_int8) { + engine_->SetTensorDynamicRange(reshape_itensor, in_scale); + } } else { PADDLE_ENFORCE_NE(input_dims, 1, platform::errors::InvalidArgument( @@ -200,6 +204,9 @@ class FcOpConverter : public OpConverter { auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X); reshape_layer->setReshapeDimensions(reshape_dim); reshape_itensor = reshape_layer->getOutput(0); + if (enable_int8) { + engine_->SetTensorDynamicRange(reshape_itensor, in_scale); + } } regist_fc(reshape_itensor, n_output, weight, bias); } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 307f727efe9..821fdeddc98 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -58,6 +58,7 @@ struct SimpleOpTypeSetTeller : public Teller { // use this set for no calib int8. std::unordered_set int8_teller_set{"mul", "conv2d", + "conv2d_fusion", "pool2d", "relu", "depthwise_conv2d", @@ -76,6 +77,7 @@ struct SimpleOpTypeSetTeller : public Teller { "mul", "matmul", "conv2d", + "conv2d_fusion", "pool2d", "relu", "softmax", -- GitLab
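For reference, the core of delete_quant_dequant_filter_op_pass is the per-channel fake quantize/dequantize round trip it bakes into the weights before removing the fake ops: scale_i = range / abs_max_i with range = 2^(bit_length-1) - 1, each weight is multiplied by its channel scale, rounded, and divided back, and the scales are recorded on the consuming op as "weight_scale". The standalone sketch below mirrors that arithmetic only; the function and variable names are illustrative and do not appear in the patch:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Round-trips weights through the int grid, channel by channel,
    // the same way the pass does for conv2d-style weights.
    void FakeQuantDequantPerChannel(std::vector<float>* weights, int channels,
                                    const std::vector<float>& channel_abs_max,
                                    int bit_length) {
      const int range = (1 << (bit_length - 1)) - 1;  // 127 for 8 bits
      const int inner_size = static_cast<int>(weights->size()) / channels;
      for (size_t j = 0; j < weights->size(); ++j) {
        const float scale = range / channel_abs_max[j / inner_size];
        float w = (*weights)[j] * scale;  // quantize to the integer domain
        w = std::round(w);                // snap to the int8 grid
        (*weights)[j] = w / scale;        // dequantize back to float
      }
    }

    int main() {
      // 2 channels x 2 values; abs-max per channel chosen to match the data.
      std::vector<float> w = {0.30f, -0.52f, 0.11f, 0.74f};
      FakeQuantDequantPerChannel(&w, /*channels=*/2,
                                 /*channel_abs_max=*/{0.52f, 0.74f},
                                 /*bit_length=*/8);
      for (float v : w) std::printf("%f\n", v);
      return 0;
    }

The stored weights therefore already carry the quantization error, and TensorRT only needs the recorded weight_scale and out_threshold values to run the fused conv2d_fusion/fc ops in INT8.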