From 1b58ce144a340ff895dedeeab68e3a3a3ab36c06 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Sat, 2 Apr 2022 17:23:43 +0800 Subject: [PATCH] [Paddle inference] support new quant_model (#41049) * paddle inference support new quant_model --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/add_support_int8_pass.cc | 61 ++- .../ir/delete_quant_dequant_linear_op_pass.cc | 148 +++++++ .../ir/delete_quant_dequant_linear_op_pass.h | 35 ++ .../ir/delete_quant_dequant_op_pass.cc | 10 +- .../delete_weight_dequant_linear_op_pass.cc | 415 ++++++++++++++++++ .../ir/delete_weight_dequant_linear_op_pass.h | 35 ++ paddle/fluid/framework/ir/fc_fuse_pass.cc | 29 +- .../ir/gpu_cpu_map_matmul_to_mul_pass.cc | 19 +- .../framework/ir/graph_pattern_detector.cc | 101 ++++- .../framework/ir/graph_pattern_detector.h | 36 +- .../ir/multihead_matmul_fuse_pass.cc | 51 +-- .../ir/quant_conv2d_dequant_fuse_pass.cc | 11 +- .../ir/trt_map_matmul_to_mul_pass.cc | 101 ++++- .../inference/api/paddle_pass_builder.cc | 16 +- .../tensorrt/convert/activation_op.cc | 6 - .../tensorrt/convert/affine_channel_op.cc | 4 +- .../inference/tensorrt/convert/conv2d_op.cc | 13 +- .../inference/tensorrt/convert/conv3d_op.cc | 11 +- .../tensorrt/convert/deformable_conv_op.cc | 3 +- .../tensorrt/convert/elementwise_op.cc | 20 +- .../tensorrt/convert/emb_eltwise_layernorm.cc | 2 +- .../fluid/inference/tensorrt/convert/fc_op.cc | 60 +-- .../tensorrt/convert/group_norm_op.cc | 2 +- .../tensorrt/convert/leaky_relu_op.cc | 4 +- .../inference/tensorrt/convert/matmul_op.cc | 4 +- .../tensorrt/convert/multihead_matmul_op.cc | 46 +- .../inference/tensorrt/convert/op_converter.h | 88 ++-- .../inference/tensorrt/convert/pool2d_op.cc | 7 +- .../inference/tensorrt/convert/pool3d_op.cc | 5 +- .../convert/preln_emb_eltwise_layernorm.cc | 2 +- .../tensorrt/convert/preln_skip_layernorm.cc | 2 +- .../inference/tensorrt/convert/prelu_op.cc | 4 +- .../tensorrt/convert/skip_layernorm.cc | 2 +- paddle/fluid/inference/tensorrt/engine.cc | 4 +- paddle/fluid/inference/tensorrt/engine.h | 3 +- .../operators/compat/dequantize_linear.pbtxt | 25 ++ paddle/fluid/operators/compat/mul.pbtxt | 10 +- .../operators/compat/quantize_linear.pbtxt | 25 ++ .../test_trt_convert_multihead_matmul.py | 9 +- 40 files changed, 1146 insertions(+), 285 deletions(-) create mode 100644 paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc create mode 100644 paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h create mode 100644 paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc create mode 100644 paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h create mode 100644 paddle/fluid/operators/compat/dequantize_linear.pbtxt create mode 100644 paddle/fluid/operators/compat/quantize_linear.pbtxt diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 7aaaef712a6..8cacf34834a 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -86,6 +86,8 @@ pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(delete_quant_dequant_filter_op_pass inference) +pass_library(delete_weight_dequant_linear_op_pass inference) +pass_library(delete_quant_dequant_linear_op_pass inference) pass_library(delete_dropout_op_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass 
base) diff --git a/paddle/fluid/framework/ir/add_support_int8_pass.cc b/paddle/fluid/framework/ir/add_support_int8_pass.cc index d157d2e934a..3a3f5c3741f 100644 --- a/paddle/fluid/framework/ir/add_support_int8_pass.cc +++ b/paddle/fluid/framework/ir/add_support_int8_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,11 +19,7 @@ namespace framework { namespace ir { #define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); -#define GET_NODES \ - GET_IR_NODE(prev_op); \ - GET_IR_NODE(prev_out); \ - GET_IR_NODE(quant_op); \ - GET_IR_NODE(quant_out); +#define GET_NODES GET_IR_NODE(quant_op); void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const { const std::string pattern_name = "add_support_int8"; @@ -37,10 +33,57 @@ void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { GET_NODES; - if (prev_op->Op()->HasAttr("out_threshold") && - quant_op->Op()->HasAttr("out_threshold")) { - quant_op->Op()->SetAttr("support_int8", true); + + bool inscale_flag = false; + bool outscale_flag = false; + auto* quanted_op_desc = quant_op->Op(); + // If inputs'tensors have the inputs_scale, then save it's index in + // input_quant_tensor_index + // OP'Attr hasn't std::vector>. To do: Support multi-tensor + // scale for one input + for (size_t i = 0; i < quanted_op_desc->InputNames().size(); i++) { + if (quanted_op_desc->Input(quanted_op_desc->InputNames()[i]).size() > 0 && + quanted_op_desc->HasAttr( + "Input_scale_" + + quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0])) { + inscale_flag = true; + quanted_op_desc->SetAttr( + quanted_op_desc->InputNames()[i], + quanted_op_desc->GetAttr( + "Input_scale_" + + quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0])); + } + } + + // If outputs'tensors have the outputs_scale, then save it's index in + // output_quant_tensor_index + // OP'Attr hasn't std::vector>. To do: Support multi-tensor + // scale for one output + for (auto out_node : quant_op->outputs) { + for (auto out_op_node : out_node->outputs) { + for (auto name : out_op_node->Op()->InputNames()) { + for (auto input_name : out_op_node->Op()->Input(name)) { + if (out_op_node->Op()->HasAttr("Input_scale_" + input_name)) { + for (size_t i = 0; i < quanted_op_desc->OutputNames().size(); + i++) { + if (quanted_op_desc->Output(quanted_op_desc->OutputNames()[i]) + .size() > 0 && + input_name == + quanted_op_desc->Output( + quanted_op_desc->OutputNames()[i])[0]) { + outscale_flag = true; + quanted_op_desc->SetAttr( + quanted_op_desc->OutputNames()[i], + out_op_node->Op()->GetAttr("Input_scale_" + input_name)); + } + } + } + } + } + } } + quanted_op_desc->SetAttr("support_int8", inscale_flag && outscale_flag); + quanted_op_desc->Flush(); found_count++; }; gpd(graph, handler); diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc new file mode 100644 index 00000000000..8f2b58ed51b --- /dev/null +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h" + +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(quantize_linear_op_x); \ + GET_IR_NODE(quantize_linear_op_scale); \ + GET_IR_NODE(quantize_linear_op); \ + GET_IR_NODE(quantize_linear_op_out); \ + GET_IR_NODE(dequantize_linear_op); \ + GET_IR_NODE(dequantize_linear_op_out); \ + GET_IR_NODE(any_op2); + +DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { + AddOpCompat(OpCompat("quantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End(); + AddOpCompat(OpCompat("dequantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End(); +} +// Delete quantize_linear_op dequantize_linear_op, then add input_scales +void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "delete_quantdequant_linear_op_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::InvalidArgument( + "Scope in DeleteQuantDequantLinearOpPass should not be null.")); + // Create pattern + patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(), + pattern_name); + pattern(); + int found_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + /* + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "delete_quant_dequant_linear_op_pass " + "compat check failed."; + return; + } + */ + std::unordered_set nodes2rm = {}; + int bit_length = + BOOST_GET_CONST(int, quantize_linear_op->Op()->GetAttr("bit_length")); + int range = ((1 << (bit_length - 1)) - 1); + + // Get input scale from tensor + const LoDTensor& input_scale_tensor = + scope->GetVar(quantize_linear_op_scale->Name())->Get(); + PADDLE_ENFORCE_EQ( + paddle::platform::is_cpu_place(input_scale_tensor.place()), true, + platform::errors::InvalidArgument( + "Input scale tensor's place should be CPU.")); + const float* input_scale_data = input_scale_tensor.data(); + float input_scale = input_scale_data[0] / range; + + auto* any_op2_desc = any_op2->Op(); + any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), + input_scale); + + nodes2rm.insert(quantize_linear_op_scale); + nodes2rm.insert(quantize_linear_op); + 
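// Editorial note, for illustration only (not part of the original commit):
// with the default bit_length of 8, range is (1 << 7) - 1 = 127, so a Scale
// tensor holding 12.7 yields Input_scale = 12.7 / 127 = 0.1. The value is
// attached to the consumer op (any_op2) under the key
// "Input_scale_" + <name of the original fp32 input>, which is what
// add_support_int8_pass later copies onto the matching input slot.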
nodes2rm.insert(quantize_linear_op_out); + nodes2rm.insert(dequantize_linear_op); + nodes2rm.insert(dequantize_linear_op_out); + + // link x to any_op2 + any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), + quantize_linear_op_x->Var()->Name()); + any_op2_desc->Flush(); + IR_NODE_LINK_TO(quantize_linear_op_x, any_op2); + GraphSafeRemoveNodes(graph, nodes2rm); + found_count++; + }; + gpd(graph, handler); + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_quant_dequant_linear_op_pass, + paddle::framework::ir::DeleteQuantDequantLinearOpPass); diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h new file mode 100644 index 00000000000..b00e3cb5c46 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class DeleteQuantDequantLinearOpPass : public FusePassBase { + public: + DeleteQuantDequantLinearOpPass(); + virtual ~DeleteQuantDequantLinearOpPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc index 63d68bd04b5..e2bb62dba7c 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc @@ -61,7 +61,6 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const { GET_NODES; int bit_length = BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length")); - int range = ((1 << (bit_length - 1)) - 1); // Get input scale from tensor std::string input_scale_var_name = @@ -76,7 +75,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const { platform::errors::InvalidArgument( "Input scale tensor's place should be CPU.")); const float* input_scale_data = input_scale_tensor.data(); - float input_scale = input_scale_data[0] / range; + float input_scale = input_scale_data[0]; // Set input scale in attr, and relink nodes std::string input_name = input->Var()->Name(); @@ -85,12 +84,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const { for (auto* quantized_node : outlinks) { auto op_desc = quantized_node->Op(); std::string quantized_op_type = op_desc->Type(); - if (quantized_op_type == "mul" || quantized_op_type == "matmul" || - quantized_op_type == "matmul_v2") { - op_desc->SetAttr("X_scale", input_scale); - } else { - op_desc->SetAttr("Input_scale", input_scale); - } + op_desc->SetAttr("Input_scale", input_scale); op_desc->SetAttr("bit_length", bit_length); 
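// Editorial note, for illustration only (not part of the original commit):
// every quantized consumer (conv2d, fc, mul, matmul, matmul_v2 alike) now
// receives the single attribute name "Input_scale", and the stored value is
// the raw scale-tensor entry rather than entry / range; the matching "* 127"
// is dropped from the TensorRT conv2d/conv3d/fc converters later in this
// patch, so the dynamic range ultimately handed to TensorRT should be
// unchanged.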
op_desc->RenameInput(quant_dequant_output_name, input_name); op_desc->Flush(); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc new file mode 100644 index 00000000000..8ebea231e7a --- /dev/null +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc @@ -0,0 +1,415 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h" + +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(weight_dequantize_linear_op_x); \ + GET_IR_NODE(weight_dequantize_linear_op_scale); \ + GET_IR_NODE(weight_dequantize_linear_op); \ + GET_IR_NODE(weight_dequantize_linear_op_out); \ + GET_IR_NODE(any_op2); + +DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() { + AddOpCompat(OpCompat("quantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End(); + AddOpCompat(OpCompat("dequantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End(); + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("depthwise_conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + 
.IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") + .IsBoolEQ(false) + .End(); + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("in_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("activation_type") + .IsStringIn({"relu", ""}) + .End(); + AddOpCompat(OpCompat("conv2d_transpose")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("output_padding") + .IsType>() + .IsOptional() + .End() + .AddAttr("output_size") + .IsType>() + .IsOptional() + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); +} +// Delete dequantize_linear_op, then dequantize weight +void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = + "delete_weight_quantdequant_linear_op_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::InvalidArgument( + "Scope in DeleteWeightQuantDequantLinearOpPass should not be null.")); + // Create pattern + patterns::DeleteWeightQuantDequantLinearOpPattern pattern( + gpd.mutable_pattern(), pattern_name); + pattern(); + int found_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + /* + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "delete_weight_dequant_linear_op_pass " + "compat check failed."; + return; + } + */ + std::unordered_set nodes2rm = {}; + int bit_length = BOOST_GET_CONST( + int, weight_dequantize_linear_op->Op()->GetAttr("bit_length")); + int range = ((1 << (bit_length - 1)) - 1); + + auto* any_op2_desc = any_op2->Op(); + + // get weight tensor + auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name()) + ->GetMutable(); + int8_t* quantized_weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + auto w_dims = weight_tensor->dims(); + + // Get weight scale + std::vector weight_scale; + auto* weight_scale_tensor = + scope->GetVar(weight_dequantize_linear_op_scale->Name()) + ->GetMutable(); + float* weight_scale_data = + weight_scale_tensor->mutable_data(platform::CPUPlace()); + + auto 
weight_scale_nums = weight_scale_tensor->numel(); + for (int i = 0; i < weight_scale_nums; i++) { + weight_scale.push_back(weight_scale_data[i] / range); + } + + // dequant weight + std::vector weight_data_tmp; + weight_data_tmp.reserve(weight_tensor->numel()); + + int quant_axis = BOOST_GET_CONST( + int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis")); + if (quant_axis == -1) { // per_layer quant_dequant: all OP + PADDLE_ENFORCE_EQ(weight_scale_nums, 1, + platform::errors::InvalidArgument( + "When quant_axis == -1 means use per_layer " + "quant_dequant, weight_scale'number should be 1.")); + + // float(weight) * scale + for (int i = 0; i < weight_tensor->numel(); i++) { + weight_data_tmp[i] = + static_cast(quantized_weight_data[i]) * weight_scale[0]; + } + } else if (quant_axis == 0) { // per_channel quant_dequant: conv2d, + // depthwise_conv2d, conv2d_fusion + PADDLE_ENFORCE_EQ( + weight_scale_nums, w_dims[quant_axis], + platform::errors::InvalidArgument( + "When quant_axis == 0 means use per_channel quant_dequant, " + "weight_scale'numbers should be equal channels.")); + PADDLE_ENFORCE_EQ(w_dims.size(), 4, + platform::errors::InvalidArgument( + "When quant_axis == 0 means use per_channel " + "quant_dequant, (conv2d, depthwise_conv2d, " + "conv2d_fusion)'s weight dims should be 4.")); + + for (int i = 0; i < weight_tensor->numel(); i++) { + int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; + weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * + weight_scale[i / inner_size]; + } + } else if (quant_axis == 1) { + PADDLE_ENFORCE_EQ( + weight_scale_nums, w_dims[quant_axis], + platform::errors::InvalidArgument( + "When quant_axis == 1 means use per_channel quant_dequant, " + "weight_scale'numbers should be equal channels.")); + + if (w_dims.size() == 4) { // conv2d_transpose + std::string quantized_op_type = any_op2->Op()->Type(); + PADDLE_ENFORCE_EQ( + quantized_op_type, "conv2d_transpose", + platform::errors::InvalidArgument( + "When quant_axis == 1 means use per_channel quant_dequant, " + "only conv2d_transpose weight dims equal 4.")); + for (int i = 0; i < weight_tensor->numel(); i++) { + int inner_size = w_dims[2] * w_dims[3]; + weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * + weight_scale[(i / inner_size) % w_dims[1]]; + } + } else if (w_dims.size() == 2) { + for (int i = 0; i < weight_tensor->numel(); i++) { + weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * + weight_scale[i % w_dims[1]]; + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "When quant_axis == 1 , weight dims should be 2 or 4, please check " + "your model ")); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "quant_axis should be -1 or 0 or 1, please check your model " + "OP'attribute ")); + } + weight_tensor->clear(); // clear int weight + weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims))); + float* new_quantized_weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_quantized_weight_data, weight_data_tmp.data(), + weight_tensor->numel() * sizeof(float)); + + nodes2rm.insert(weight_dequantize_linear_op_scale); + nodes2rm.insert(weight_dequantize_linear_op); + nodes2rm.insert(weight_dequantize_linear_op_out); + + // relink weight to any_op2 + any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(), + weight_dequantize_linear_op_x->Var()->Name()); + any_op2_desc->Flush(); + IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2); + GraphSafeRemoveNodes(graph, nodes2rm); + found_count++; + 
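// Editorial example (not part of the original commit): for a conv2d filter
// of shape [O, C, H, W] with quant_axis == 0, inner_size is C * H * W and
// element i is rescaled by weight_scale[i / inner_size], i.e. one scale per
// output channel; with quant_axis == 1 and a 2-D mul/matmul weight of shape
// [K, N], element i uses weight_scale[i % N], one scale per output column.
// The scales themselves were already divided by range = 2^(bit_length-1) - 1
// above, matching the int8 storage of the weights.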
}; + gpd(graph, handler); + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_weight_dequant_linear_op_pass, + paddle::framework::ir::DeleteWeightQuantDequantLinearOpPass); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h new file mode 100644 index 00000000000..e240b6212b8 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class DeleteWeightQuantDequantLinearOpPass : public FusePassBase { + public: + DeleteWeightQuantDequantLinearOpPass(); + virtual ~DeleteWeightQuantDequantLinearOpPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index e246a10961c..1e25b21483b 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -226,23 +226,34 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { // For anakin subgraph int8 // When in anakin subgraph int8 mode, the pattern like "fake_quant + mul + // fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass - // will add "input_scale", "weight_scale" which are extracted from + // will add "input_scale" which are extracted from // fake_quant op and fake_dequant op to mul op, and then delete the // fake_quant op and fake_dequant op in the graph. If the mul op has the // scale info, we should add those to the fused fc. 
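// Editorial summary of the block below (not part of the original commit):
// the fused fc inherits enable_int8 and Input_scale from the mul op, takes
// the per-tensor input scale from the mul's "X" attribute and the output
// scale from the elementwise_add's "Out" attribute (both written earlier by
// add_support_int8_pass), and marks support_int8 true only when both scales
// are present.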
auto* mul_op_desc = mul->Op(); + auto* elementwise_add_op_desc = elementwise_add->Op(); + if (mul_op_desc->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8")); - desc.SetAttr("Input_scale", mul_op_desc->GetAttr("X_scale")); - desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale")); - if (mul_op_desc->HasAttr("out_scale")) - desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale")); - auto elementwise_desc = elementwise_add->Op(); - if (elementwise_desc->HasAttr("out_scale")) - desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale")); } - auto* elementwise_add_op_desc = elementwise_add->Op(); + if (mul_op_desc->HasAttr("Input_scale")) { + desc.SetAttr("Input_scale", mul_op_desc->GetAttr("Input_scale")); + } + + bool inscale_flag = false; + bool outscale_flag = false; + + if (mul_op_desc->HasAttr("X")) { + desc.SetAttr("X", mul_op_desc->GetAttr("X")); + inscale_flag = true; + } + if (elementwise_add_op_desc->HasAttr("Out")) { + desc.SetAttr("Out", elementwise_add_op_desc->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag && outscale_flag); + // if we can find out_threshold in elementwise_add, then set it as the // out_thrshold of fc auto out_threshold_attr = diff --git a/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc index 1759d18761d..ac580b99b5c 100644 --- a/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc @@ -298,8 +298,7 @@ void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } @@ -372,9 +371,7 @@ void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_v2_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", - matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold")); } @@ -451,8 +448,7 @@ void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { } if (matmul_v2_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold")); } @@ -532,8 +528,7 @@ void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + 
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } @@ -677,8 +672,7 @@ void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } @@ -765,8 +759,7 @@ void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 164a13d1560..03da1289205 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2949,6 +2949,84 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() { any_op2->LinksFrom({quant_dequant_out}); } +void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() { + auto weight_dequantize_linear_op_x = + pattern->NewNode(weight_dequantize_linear_op_x_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_linear", "X") + ->assert_is_persistable_var(); + + auto weight_dequantize_linear_op_scale = + pattern->NewNode(weight_dequantize_linear_op_scale_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_linear", "Scale") + ->assert_is_persistable_var(); + + auto weight_dequantize_linear_op = + pattern->NewNode(weight_dequantize_linear_op_repr()) + ->assert_is_op("dequantize_linear"); + + auto weight_dequantize_linear_op_out = + pattern->NewNode(weight_dequantize_linear_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("dequantize_linear", "Y"); + + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + weight_dequantize_linear_op + ->LinksFrom( + {weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale}) + .LinksTo({weight_dequantize_linear_op_out}); + any_op2->LinksFrom({weight_dequantize_linear_op_out}); +} + +void patterns::DeleteQuantDequantLinearOpPattern::operator()() { + auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr()) + ->AsInput() + ->assert_is_op_input("quantize_linear", "X"); + + auto quantize_linear_op_scale = + pattern->NewNode(quantize_linear_op_scale_repr()) + ->AsInput() + ->assert_is_op_input("quantize_linear", "Scale") + ->assert_is_persistable_var(); + + auto quantize_linear_op = pattern->NewNode(quantize_linear_op_repr()) + ->assert_is_op("quantize_linear"); + + auto quantize_linear_op_out = + pattern->NewNode(quantize_linear_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("quantize_linear", "Y") + ->assert_is_op_input("dequantize_linear", "X") + ->assert_var_not_persistable(); + + // Can not add this node. 
Todo: Wangzheee + /* + auto dequantize_linear_op_scale = + pattern->NewNode(dequantize_linear_op_scale_repr()) + ->assert_is_op_input("dequantize_linear", "Scale") + ->AsIntermediate(); + */ + + auto dequantize_linear_op = pattern->NewNode(dequantize_linear_op_repr()) + ->assert_is_op("dequantize_linear"); + + auto dequantize_linear_op_out = + pattern->NewNode(dequantize_linear_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("dequantize_linear", "Y"); + + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + quantize_linear_op + ->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale}) + .LinksTo({quantize_linear_op_out}); + dequantize_linear_op->LinksFrom({quantize_linear_op_out}) + .LinksTo({dequantize_linear_op_out}); + any_op2->LinksFrom({dequantize_linear_op_out}); +} + PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( const std::string &op_name, bool with_reshape_xshape, bool with_transpose_xshape) { @@ -3311,25 +3389,14 @@ PDNode *patterns::LayerNorm::operator()() { return shift_out; } -// Add support int8 flag +// Add support int8 flag and out_threshold PDNode *patterns::AddSupportInt8::operator()() { - auto prev_op = - pattern->NewNode(prev_op_repr()) - ->assert_is_op() - ->assert_more([&](Node *node) { - return node->Op()->HasAttr("out_threshold") ? true : false; - }); - auto prev_out = pattern->NewNode(prev_out_repr())->assert_is_var(); - auto quant_op = - pattern->NewNode(quant_op_repr()) - ->assert_is_op() - ->assert_more([&](Node *node) { - return node->Op()->HasAttr("out_threshold") ? true : false; - }); + auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op(); auto quant_out = - pattern->NewNode(quant_out_repr())->assert_is_var()->AsOutput(); - prev_op->LinksTo({prev_out}); - prev_out->LinksTo({quant_op}); + pattern->NewNode(quant_out_repr()) + ->assert_is_var() + ->assert_more([&](Node *node) { return node->outputs.size() > 0; }) + ->AsOutput(); quant_op->LinksTo({quant_out}); return quant_out; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 17c70ace301..1f253c6b910 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1702,6 +1702,40 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase { PATTERN_DECL_NODE(any_op2); }; +struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase { + DeleteWeightQuantDequantLinearOpPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, + "delete_weight_quant_dequant_linear_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(weight_dequantize_linear_op_x); + PATTERN_DECL_NODE(weight_dequantize_linear_op_scale); + PATTERN_DECL_NODE(weight_dequantize_linear_op); + PATTERN_DECL_NODE(weight_dequantize_linear_op_out); + PATTERN_DECL_NODE(any_op2); +}; + +struct DeleteQuantDequantLinearOpPattern : public PatternBase { + DeleteQuantDequantLinearOpPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, + "delete_quant_dequant_linear_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(quantize_linear_op_x); + PATTERN_DECL_NODE(quantize_linear_op_scale); + PATTERN_DECL_NODE(quantize_linear_op); + PATTERN_DECL_NODE(quantize_linear_op_out); + PATTERN_DECL_NODE(dequantize_linear_op); + // PATTERN_DECL_NODE(dequantize_linear_op_scale); // Can not add this node. 
+ // Todo: Wangzheee + PATTERN_DECL_NODE(dequantize_linear_op_out); + PATTERN_DECL_NODE(any_op2); +}; + // Reshape + Transpose + Matmul // named nodes: // reshape_op, reshape_out, reshape_xshape, @@ -1887,8 +1921,6 @@ struct AddSupportInt8 : public PatternBase { : PatternBase(pattern, name_scope, "Add_support_int8") {} PDNode* operator()(); - PATTERN_DECL_NODE(prev_op); - PATTERN_DECL_NODE(prev_out); PATTERN_DECL_NODE(quant_op); PATTERN_DECL_NODE(quant_out); }; diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 989b5460743..a8595d55b31 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -862,43 +862,30 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, multihead_op_desc.SetAttr("head_number", head_number); auto* mul0_op_desc = mul0->Op(); - auto* mul1_op_desc = mul1->Op(); - auto* mul2_op_desc = mul2->Op(); - if (mul0_op_desc->HasAttr("enable_int8")) { - multihead_op_desc.SetAttr("enable_int8", - mul0_op_desc->GetAttr("enable_int8")); - // all mul op has same input. + + // all mul op has same input. + if (multihead_op_desc.HasAttr("Input_scale")) { multihead_op_desc.SetAttr("Input_scale", - mul0_op_desc->GetAttr("X_scale")); - auto weight_scale0 = BOOST_GET_CONST( - std::vector, mul0_op_desc->GetAttr("weight_scale")); - auto weight_scale1 = BOOST_GET_CONST( - std::vector, mul1_op_desc->GetAttr("weight_scale")); - auto weight_scale2 = BOOST_GET_CONST( - std::vector, mul2_op_desc->GetAttr("weight_scale")); - auto weight_max = std::max(weight_scale0, weight_scale1); - weight_max = std::max(weight_max, weight_scale2); - multihead_op_desc.SetAttr("weight_scale", weight_max); - - auto* add0_op_desc = eltadd0->Op(); - auto* add1_op_desc = eltadd1->Op(); - auto* add2_op_desc = eltadd2->Op(); - if (add0_op_desc->HasAttr("out_threshold")) { - auto out_scale0 = - BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold")); - auto out_scale1 = - BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold")); - auto out_scale2 = - BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold")); - auto out_scale_max = std::max(out_scale0, out_scale1); - out_scale_max = std::max(out_scale_max, out_scale2); - multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max); - } + mul0_op_desc->GetAttr("Input_scale")); + } + auto* add0_op_desc = eltadd0->Op(); + auto* add1_op_desc = eltadd1->Op(); + auto* add2_op_desc = eltadd2->Op(); + if (add0_op_desc->HasAttr("out_threshold")) { + auto out_scale0 = + BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold")); + auto out_scale1 = + BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold")); + auto out_scale2 = + BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold")); + auto out_scale_max = std::max(out_scale0, out_scale1); + out_scale_max = std::max(out_scale_max, out_scale2); + multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max); } auto* softmax_qk_op_desc = softmax_qk->Op(); auto* matmul_qk_op_desc = matmul_qk->Op(); - if (matmul_qk_op_desc->HasAttr("X_scale")) { + if (matmul_qk_op_desc->HasAttr("Input_scale")) { multihead_op_desc.SetAttr("qkv2context_plugin_int8", true); if (softmax_qk_op_desc->HasAttr("out_threshold")) { auto qkv_plugin_scale = BOOST_GET_CONST( diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 619fe7ab4f7..281e0b99106 100644 --- 
a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -341,7 +341,6 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, Node* output_scale = subgraph.at(pattern.GetPDNode("output_scale_node")); Node* output_act = subgraph.at(pattern.GetPDNode("output_act_node")); int bit_length = BOOST_GET_CONST(int, quant->Op()->GetAttr("bit_length")); - int range = ((1 << (bit_length - 1)) - 1); // Get input scale from tensor std::string input_scale_var_name = quant->Op()->Input("InScale").front(); @@ -356,7 +355,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, "Input scale tensor's place should be CPU.")); const float* input_scale_data = input_scale_tensor.data(); float in_scale = input_scale_data[0]; - float scale_value = in_scale / range; + float scale_value = in_scale; // Set input scale in attr, and relink nodes std::string input_act_name = input_act->Var()->Name(); @@ -369,11 +368,10 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, quantized_op_type == "conv2d_fusion" || quantized_op_type == "depthwise_conv2d" || quantized_op_type == "fc" || - quantized_op_type == "conv2d_transpose") { + quantized_op_type == "conv2d_transpose" || + quantized_op_type == "mul" || quantized_op_type == "matmul" || + quantized_op_type == "matmul_v2") { op_desc->SetAttr("Input_scale", scale_value); - } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" || - quantized_op_type == "matmul_v2") { - op_desc->SetAttr("X_scale", scale_value); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported quantized op type %s.", quantized_op_type)); @@ -619,7 +617,6 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, new_op_desc.SetInput("X", {new_input}); new_op_desc.SetOutput("Out", {new_output}); } - new_op_desc.SetAttr("weight_scale", weight_scale); new_op_desc.Flush(); auto* new_op = graph->CreateOpNode(&new_op_desc); IR_NODE_LINK_TO(quantized_op_input_node, new_op); diff --git a/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc index 3caaf08dc9c..d3211c08414 100644 --- a/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc @@ -297,11 +297,24 @@ void TrtMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("transpose_Y", matmul_op->Op()->GetAttr("transpose_Y")); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag = false; + bool outscale_flag = false; + + if (matmul_op->Op()->HasAttr("X")) { + desc.SetAttr("X", matmul_op->Op()->GetAttr("X")); + inscale_flag = true; + } + if (matmul_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag && outscale_flag); + auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); @@ -370,12 +383,23 @@ void TrtMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("transpose_Y", 
matmul_v2_op->Op()->GetAttr("trans_y")); if (matmul_v2_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", - matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag = false; + bool outscale_flag = false; + if (matmul_v2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X")); + inscale_flag = true; + } + if (matmul_v2_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag && outscale_flag); + auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_v2_in_x, mul_node); IR_NODE_LINK_TO(matmul_v2_in_y, mul_node); @@ -448,11 +472,23 @@ void TrtMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { } if (matmul_v2_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag = false; + bool outscale_flag = false; + if (matmul_v2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X")); + inscale_flag = true; + } + if (matmul_v2_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag && outscale_flag); + auto matmul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node); IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node); @@ -530,11 +566,24 @@ void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag_x = false; + bool outscale_flag = false; + + if (squeeze2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", squeeze2_op->Op()->GetAttr("X")); + inscale_flag_x = true; + } + if (matmul_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag_x && outscale_flag); + auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(squeeze2_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); @@ -675,11 +724,24 @@ void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", 
matmul_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag_x = false; + bool outscale_flag = false; + + if (reshape2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", reshape2_op->Op()->GetAttr("X")); + inscale_flag_x = true; + } + if (matmul_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag_x && outscale_flag); + if (!IsCompat(desc)) { LOG(WARNING) << "TrtReshape2MatmulFusePass in out mul op compat failed."; @@ -763,11 +825,24 @@ void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("y_num_col_dims", 1); if (matmul_op->Op()->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); - desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); - desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale")); desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold")); } + + bool inscale_flag_x = false; + bool outscale_flag = false; + + if (flatten2_op->Op()->HasAttr("X")) { + desc.SetAttr("X", flatten2_op->Op()->GetAttr("X")); + inscale_flag_x = true; + } + if (matmul_op->Op()->HasAttr("Out")) { + desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out")); + outscale_flag = true; + } + desc.SetAttr("support_int8", inscale_flag_x && outscale_flag); + auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(flatten2_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 95975d8f2a8..20418e37a7b 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -76,10 +76,13 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); } const std::vector kTRTSubgraphPasses({ "adaptive_pool2d_convert_global_pass", - "shuffle_channel_detect_pass", // - "quant_conv2d_dequant_fuse_pass", // - "delete_quant_dequant_op_pass", // - "delete_quant_dequant_filter_op_pass", // + "shuffle_channel_detect_pass", // + "quant_conv2d_dequant_fuse_pass", // + "delete_quant_dequant_op_pass", // + "delete_quant_dequant_filter_op_pass", // + "delete_weight_dequant_linear_op_pass", // + "delete_quant_dequant_linear_op_pass", // + "add_support_int8_pass", // // "fc_fuse_pass", // "simplify_with_basic_ops_pass", // "embedding_eltwise_layernorm_fuse_pass", // @@ -98,9 +101,8 @@ const std::vector kTRTSubgraphPasses({ "trt_map_matmul_to_mul_pass", // "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // - "add_support_int8_pass", - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index e6a0ecf4aec..b86351e394b 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -68,12 +68,6 @@ class ActivationOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode); - if (op_desc.HasAttr("out_scale")) { -#if IS_TRT_VERSION_GE(5130) - float out_scale = BOOST_GET_CONST(float, 
op_desc.GetAttr("out_scale")); - engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale); -#endif - } } protected: diff --git a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc index eba67c3c098..cc06f82ae39 100644 --- a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc @@ -49,11 +49,11 @@ class AffineChannelOpConverter : public OpConverter { auto* scale_v = scope.FindVar(scale_name); auto* scale_t = scale_v->GetMutable(); - float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t, false); + float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t); auto* bias_v = scope.FindVar(bias_name); auto* bias_t = bias_v->GetMutable(); - float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false); + float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t); // tensorrt scalend layer only support spatial dims >= 2, // so nhwc is not availabe (spatial dims == 0) diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index a296a2641db..1b2abeac6c1 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -49,18 +49,11 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, if (enable_int8) { #if IS_TRT_VERSION_GE(5000) - float in_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; - auto weight_scale = - BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); - weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, - true, weight_scale); + float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine->SetTensorDynamicRange(X, in_scale); #endif - } else { - weight_data = - engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false); } + weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t); PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL, platform::errors::InvalidArgument( @@ -115,7 +108,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front()); auto* bias_tensor_data = bias_tensor->GetMutable(); bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(), - bias_tensor_data, false); + bias_tensor_data); bias_size = static_cast(bias_tensor_data->numel()); } diff --git a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc index dae92264d2c..dbb2786ed78 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc @@ -48,17 +48,10 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op, bool enable_int8 = op_desc.HasAttr("enable_int8"); if (enable_int8) { - float in_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; - auto weight_scale = - BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); - weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, - true, weight_scale); + float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine->SetTensorDynamicRange(X, in_scale); - } else { - weight_data = - engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false); } + weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t); 
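// Editorial note (not part of the original commit): GetWeightCPUData is now
// called without the former (bool, weight_scale) arguments because the
// filter reaching the converter has already been dequantized to float by
// delete_weight_dequant_linear_op_pass; the converter only forwards
// "Input_scale" to SetTensorDynamicRange, without the former "* 127".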
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL, platform::errors::InvalidArgument( diff --git a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc index d8534a4183b..2bbe6ea3d2f 100644 --- a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc @@ -47,8 +47,7 @@ class DeformableConvOpConverter : public OpConverter { auto* filter_var = scope.FindVar(filter_name); auto* filter_tensor = filter_var->GetMutable(); - float* filter_data = - engine_->GetWeightCPUData(filter_name, filter_tensor, false); + float* filter_data = engine_->GetWeightCPUData(filter_name, filter_tensor); const int c_o = filter_tensor->dims()[0]; const int c_i = filter_tensor->dims()[1]; diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index a66a97b4be9..8fd0e1bbd06 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -51,8 +51,7 @@ class ElementwiseWeightOpConverter : public OpConverter { auto* Y_t = Y_v->GetMutable(); float* weight_data = nullptr; auto output_name = op_desc.Output("Out")[0]; - weight_data = - engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t, false); + weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t); nvinfer1::Dims dims_x = X->getDimensions(); auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) { @@ -112,13 +111,6 @@ class ElementwiseWeightOpConverter : public OpConverter { RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name}, test_mode); } - if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); - engine_->SetTensorDynamicRange(X, x_scale); -#endif - } }; if (engine_->with_dynamic_shape()) { @@ -222,16 +214,6 @@ class ElementwiseTensorOpConverter : public OpConverter { auto common_func = [&](nvinfer1::ILayer* layer) { RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); - if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - CHECK(op_desc.HasAttr("Y_scale")); - float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); - float y_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Y_scale")); - engine_->SetTensorDynamicRange(X, x_scale); - engine_->SetTensorDynamicRange(Y, y_scale); -#endif - } }; if (dims_x.nbDims == dims_y.nbDims) { diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index 9741aab32de..7a494860e6f 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -77,7 +77,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index bdea14c9e9f..a631332dae3 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -113,22 
+113,20 @@ class FcOpConverter : public OpConverter { // assigned from CPU memory, which can't be avoided. float* weight_data = nullptr; bool enable_int8 = op_desc.HasAttr("enable_int8"); - float in_scale = 0.; - if (enable_int8) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr(i_name + "_scale")); - in_scale = - BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127; - auto weight_scale = - BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); - weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), - Y_t, true, weight_scale); + bool support_int8 = false; + if (op_desc.HasAttr("support_int8")) { + support_int8 = BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")); + } + float in_scale = 0; + if (enable_int8 || support_int8) { + if (enable_int8) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); + } else { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X")); + } engine_->SetTensorDynamicRange(X, in_scale); -#endif - } else { - weight_data = - engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t, false); } + weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t); PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL, platform::errors::InvalidArgument( @@ -148,14 +146,18 @@ class FcOpConverter : public OpConverter { auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, TensorRTEngine::Weight& weight, TensorRTEngine::Weight& bias) { - if (enable_int8) { + if (enable_int8 || support_int8) { // add conv layer - PADDLE_ENFORCE_EQ( - op_desc.HasAttr("out_threshold"), true, - platform::errors::InvalidArgument( - "must have out threshold in fc layers in int8 mode")); - float out_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + float out_scale = 0; + if (enable_int8) { + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in fc layers in int8 mode")); + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + } else { + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); + } nvinfer1::DimsHW nv_ksize(1, 1); auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, @@ -235,8 +237,7 @@ class FcOpConverter : public OpConverter { if (with_bias) { auto* b_v = scope.GetVar(op_desc.Input("Bias").front()); auto* b_t = b_v->GetMutable(); - bias_data = - engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t, false); + bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t); bias_num = b_t->numel(); } TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, @@ -251,7 +252,7 @@ class FcOpConverter : public OpConverter { // not add Shuffle layer in ernie's multihead. 
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && x_dim.d[3] == 1 && x_num_col_dims == 2) { - if (enable_int8) { + if (enable_int8 || support_int8) { // add conv1x1 layer nvinfer1::DimsHW nv_ksize(1, 1); auto* fc_layer_int8 = @@ -265,8 +266,13 @@ class FcOpConverter : public OpConverter { op_desc.HasAttr("out_threshold"), true, platform::errors::InvalidArgument( "must have out threshold in fc layers in int8 mode")); - float out_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + float out_scale = 0; + if (enable_int8) { + out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + } else { + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); + } engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale); nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( @@ -308,7 +314,7 @@ class FcOpConverter : public OpConverter { auto* reshape_before_fc_layer = reshape_before_fc(X, x_dim, x_num_col_dims, output_name); auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); - if (enable_int8) { + if (enable_int8 || support_int8) { engine_->SetTensorDynamicRange(reshape_itensor, in_scale); } regist_fc(reshape_itensor, n_output, weight, bias); diff --git a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc index b3c1f986aa0..910a807d362 100644 --- a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc @@ -48,7 +48,7 @@ class GroupNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc index c6dbfc83220..c7a551b7436 100644 --- a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc @@ -49,8 +49,8 @@ class LeakyReluOpConverter : public OpConverter { bool enable_int8 = op_desc.HasAttr("enable_int8"); if (enable_int8) { - CHECK(op_desc.HasAttr("X_scale")); - float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); + CHECK(op_desc.HasAttr("Input_scale")); + float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input, in_scale); } #else diff --git a/paddle/fluid/inference/tensorrt/convert/matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/matmul_op.cc index b2e76b9a0e6..7568f67d64d 100644 --- a/paddle/fluid/inference/tensorrt/convert/matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/matmul_op.cc @@ -64,7 +64,9 @@ class MatMulOpConverter : public OpConverter { : nvinfer1::MatrixOperation::kNONE; if (op_desc.HasAttr("support_int8") && - engine_->precision() == AnalysisConfig::Precision::kInt8) { + BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")) && + engine_->precision() == AnalysisConfig::Precision::kInt8 && + platform::GetGPUComputeCapability(0) >= 75) { if (engine_->with_dynamic_shape()) { VLOG(3) << "Convert a fluid matmul_op_int8_dynamic to TensorRT " "MatmulPluginLayer"; diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index f19b21d3e63..21c79f0edd2 100644 --- 
a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -40,22 +40,16 @@ class MultiheadMatMulOpConverter : public OpConverter { auto* bias_t = bias_v->GetMutable(); float* weight_data = nullptr; - bool enable_int8 = op_desc.HasAttr("enable_int8"); bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8"); float in_scale = 0.; - if (enable_int8) { - in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; - auto weight_scale = - BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); - weight_data = - engine_->GetWeightCPUData(weight_name, weight_t, true, weight_scale); + if (op_desc.HasAttr("Input_scale")) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input, in_scale); - } else { - weight_data = engine_->GetWeightCPUData(weight_name, weight_t, false); } + weight_data = engine_->GetWeightCPUData(weight_name, weight_t); - float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t, false); + float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); std::vector weight_data_tmp; weight_data_tmp.reserve(weight_t->numel()); memcpy(weight_data_tmp.data(), weight_data, @@ -85,6 +79,10 @@ class MultiheadMatMulOpConverter : public OpConverter { if (engine_->with_dynamic_shape()) { if (engine_->use_oss()) { + if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { + PADDLE_THROW(platform::errors::Fatal( + "use use_oss must be int8 or half, not float32.")); + } nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), static_cast(weight_t->numel())}; @@ -93,7 +91,7 @@ class MultiheadMatMulOpConverter : public OpConverter { static_cast(bias_t->numel())}; if (engine_->with_interleaved()) { VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved"; - if (!enable_int8) { + if (!op_desc.HasAttr("Input_scale")) { PADDLE_THROW( platform::errors::Fatal("use with_interleaved must be int8.")); } @@ -213,7 +211,7 @@ class MultiheadMatMulOpConverter : public OpConverter { nvinfer1::ILayer* fc_layer = nullptr; float dp_probs = 1.0 / 127.0; - if (enable_int8) { + if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, nv_ksize, weight, bias); @@ -222,7 +220,7 @@ class MultiheadMatMulOpConverter : public OpConverter { weight, bias); } - if (enable_int8) { + if (op_desc.HasAttr("fc_out_threshold")) { PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true, platform::errors::InvalidArgument( "must have out threshold in multihead layers " @@ -241,14 +239,10 @@ class MultiheadMatMulOpConverter : public OpConverter { auto creator = GetPluginRegistry()->getPluginCreator( "CustomQKVToContextPluginDynamic", "2"); assert(creator != nullptr); - int type = static_cast((engine_->WithFp16() == 1) - ? 
nvinfer1::DataType::kHALF - : nvinfer1::DataType::kFLOAT); - if (enable_int8) { - type = static_cast(nvinfer1::DataType::kHALF); - if (qkv2context_plugin_int8) { - type = static_cast(nvinfer1::DataType::kINT8); - } + int type = static_cast(nvinfer1::DataType::kHALF); + if (qkv2context_plugin_int8 && + (engine_->precision() == AnalysisConfig::Precision::kInt8)) { + type = static_cast(nvinfer1::DataType::kINT8); } bool has_mask = true; int var_seqlen = 1; @@ -335,7 +329,7 @@ class MultiheadMatMulOpConverter : public OpConverter { reshape_before_fc_dim.d[4] = 1; auto* reshape_before_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); - if (enable_int8) { + if (op_desc.HasAttr("Input_scale")) { engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), in_scale); } @@ -346,7 +340,7 @@ class MultiheadMatMulOpConverter : public OpConverter { // add layer fc nvinfer1::ILayer* fc_layer = nullptr; - if (enable_int8) { + if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); fc_layer = TRT_ENGINE_ADD_LAYER( engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n, @@ -357,7 +351,7 @@ class MultiheadMatMulOpConverter : public OpConverter { n, weight.get(), bias.get()); } - if (enable_int8) { + if (op_desc.HasAttr("fc_out_threshold")) { PADDLE_ENFORCE_EQ( op_desc.HasAttr("fc_out_threshold"), true, platform::errors::InvalidArgument( @@ -382,8 +376,8 @@ class MultiheadMatMulOpConverter : public OpConverter { bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - if (enable_int8) { - with_fp16 = 1; + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; } plugin::DynamicPluginTensorRT* plugin = new plugin::QkvToContextPluginDynamic(hidden_in, head_number, diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 7e0c8bf1da1..f7eb7f859af 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -145,42 +145,68 @@ class OpConverter { (*it)(op, scope, test_mode); size_t output_num = op_desc.OutputNames().size(); - if (output_num == 1) { // The number of output is 1 - if (op_desc.HasAttr("out_threshold")) { - float out_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); - std::string output_name = ""; - if (op_desc.HasOutput("Output")) { - output_name = op_desc.Output("Output").front(); - } else if (op_desc.HasOutput("Out")) { - output_name = op_desc.Output("Out").front(); - } else if (op_desc.HasOutput("Y")) { - output_name = op_desc.Output("Y").front(); - } else { - PADDLE_THROW( - platform::errors::NotFound("Op %s has out threshold but doesn't " - "have an output named \"Output\", " - "\"Out\" or \"Y\".", - op_desc.Type())); - } + // only one out settensordynamicRange + if (op_desc.HasAttr("out_threshold")) { + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + std::string output_name = ""; + if (op_desc.HasOutput("Output")) { + output_name = op_desc.Output("Output").front(); + } else if (op_desc.HasOutput("Out")) { + output_name = op_desc.Output("Out").front(); + } else if (op_desc.HasOutput("Y")) { + output_name = op_desc.Output("Y").front(); + } else { + PADDLE_THROW( + platform::errors::NotFound("Op %s has out threshold but doesn't " + "have an output named \"Output\", " + "\"Out\" or \"Y\".", + op_desc.Type())); + } + auto* output_itensor = engine->GetITensor(output_name); + 
engine->SetTensorDynamicRange(output_itensor, out_scale); + VLOG(1) << "Set out scale = " << out_scale << " for tensor " + << output_name << "."; + } + // outs settensordynamicRange + for (size_t i = 0; i < output_num; ++i) { + if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) { + float out_scale = BOOST_GET_CONST( + float, op_desc.GetAttr("out_" + std::to_string(i) + "_threshold")); + std::string output_name = + op_desc.Output(op_desc.OutputNames()[i]).front(); auto* output_itensor = engine->GetITensor(output_name); engine->SetTensorDynamicRange(output_itensor, out_scale); VLOG(1) << "Set out scale = " << out_scale << " for tensor " << output_name << "."; } - } else if (output_num > 1) { // The number of outputs greater than 1 - for (size_t i = 0; i < output_num; ++i) { - if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) { - float out_scale = BOOST_GET_CONST( - float, - op_desc.GetAttr("out_" + std::to_string(i) + "_threshold")); - std::string output_name = - op_desc.Output(op_desc.OutputNames()[i]).front(); - auto* output_itensor = engine->GetITensor(output_name); - engine->SetTensorDynamicRange(output_itensor, out_scale); - VLOG(1) << "Set out scale = " << out_scale << " for tensor " - << output_name << "."; - } + } + + // quant_dequant_linear support for paddle trt + + std::vector inputs_name = op_desc.InputNames(); + std::vector outputs_name = op_desc.OutputNames(); + + for (size_t i = 0; i < inputs_name.size(); i++) { + if (op_desc.HasAttr(inputs_name[i])) { + std::string input_tensor_name = op_desc.Input(inputs_name[i])[0]; + auto* input_itensor = engine->GetITensor(input_tensor_name); + float input_scale = + BOOST_GET_CONST(float, op_desc.GetAttr(inputs_name[i])); + engine->SetTensorDynamicRange(input_itensor, input_scale); + VLOG(1) << "Set input tensor scale = " << input_scale + << " for tensor: " << input_tensor_name << "."; + } + } + for (size_t i = 0; i < outputs_name.size(); i++) { + if (op_desc.HasAttr(outputs_name[i])) { + std::string output_tensor_name = op_desc.Output(outputs_name[i])[0]; + auto* output_itensor = engine->GetITensor(output_tensor_name); + float output_scale = + BOOST_GET_CONST(float, op_desc.GetAttr(outputs_name[i])); + engine->SetTensorDynamicRange(output_itensor, output_scale); + VLOG(1) << "Set output tensor scale = " << output_scale + << " for tensor: " << output_tensor_name << "."; } } } diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 7b65d2d7c97..7824a0f1e29 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -132,11 +132,10 @@ class Pool2dOpConverter : public OpConverter { } if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); + CHECK(op_desc.HasAttr("Input_scale")); + float input_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input1, input_scale); -#endif } std::vector real_paddings = paddings; diff --git a/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc index 5a306f622ad..665bf9c8d22 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool3d_op.cc @@ -123,8 +123,9 @@ class Pool3dOpConverter : public OpConverter { nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], 
paddings[2]); nvinfer1::ILayer *layer = nullptr; if (op_desc.HasAttr("enable_int8")) { - CHECK(op_desc.HasAttr("X_scale")); - float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); + CHECK(op_desc.HasAttr("Input_scale")); + float input_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input1, input_scale); } diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index daa3b186ab4..87fdbb71a3f 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -70,7 +70,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc index d9eca65fc45..8053135cc45 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc @@ -48,7 +48,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc index 9e81d1177cf..d5b5d9bc81b 100644 --- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -57,8 +57,8 @@ class PReluOpConverter : public OpConverter { layer = engine_->AddDynamicPlugin(&input, input_num, plugin); } else { #if IS_TRT_VERSION_GE(7000) - float* alpha_weight_data = engine_->GetWeightCPUData( - op_desc.Input("Alpha")[0], alpha_tensor, false); + float* alpha_weight_data = + engine_->GetWeightCPUData(op_desc.Input("Alpha")[0], alpha_tensor); TensorRTEngine::Weight alpha_weight{ nvinfer1::DataType::kFLOAT, static_cast(alpha_weight_data), static_cast(alpha_tensor->numel())}; diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index 753cd707276..831e1173117 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -40,7 +40,7 @@ class SkipLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); + auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 794475dfc10..33386c746ae 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -356,9 +356,7 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { } float *TensorRTEngine::GetWeightCPUData(const std::string &name, - framework::Tensor *weight_tensor, - bool enable_int8, - const std::vector &scale) { + 
framework::Tensor *weight_tensor) { static int name_suffix_counter = 0; std::string name_suffix = std::to_string(name_suffix_counter); std::string splitter = "__"; diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index d53a8923af6..f781cd0cb3a 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -389,8 +389,7 @@ class TensorRTEngine { } float* GetWeightCPUData(const std::string& name, - framework::Tensor* weight_tensor, bool enable_int8, - const std::vector& scale = {}); + framework::Tensor* weight_tensor); // A pointer to CPU memory is needed of the TRT weight. // Before TRT runs, fluid loads weight into GPU storage. diff --git a/paddle/fluid/operators/compat/dequantize_linear.pbtxt b/paddle/fluid/operators/compat/dequantize_linear.pbtxt new file mode 100644 index 00000000000..73b61f8bc29 --- /dev/null +++ b/paddle/fluid/operators/compat/dequantize_linear.pbtxt @@ -0,0 +1,25 @@ +type: "dequantize_linear" +def { + inputs { + name: "X" + } + inputs { + name: "Scale" + } + inputs { + name: "ZeroPoint" + } + outputs { + name: "Y" + } + attrs { + name: "bit_length" + type: INT + } + attrs { + name: "quant_axis" + type: INT + } +} +extra { +} diff --git a/paddle/fluid/operators/compat/mul.pbtxt b/paddle/fluid/operators/compat/mul.pbtxt index 617775eaaae..056f799c6c4 100644 --- a/paddle/fluid/operators/compat/mul.pbtxt +++ b/paddle/fluid/operators/compat/mul.pbtxt @@ -60,15 +60,7 @@ extra { type: BOOLEAN } attrs { - name: "X_scale" - type: FLOAT - } - attrs { - name: "weight_scale" - type: FLOAT - } - attrs { - name: "out_scale" + name: "Input_scale" type: FLOAT } attrs { diff --git a/paddle/fluid/operators/compat/quantize_linear.pbtxt b/paddle/fluid/operators/compat/quantize_linear.pbtxt new file mode 100644 index 00000000000..7a3ca515029 --- /dev/null +++ b/paddle/fluid/operators/compat/quantize_linear.pbtxt @@ -0,0 +1,25 @@ +type: "quantize_linear" +def { + inputs { + name: "X" + } + inputs { + name: "Scale" + } + inputs { + name: "ZeroPoint" + } + outputs { + name: "Y" + } + attrs { + name: "bit_length" + type: INT + } + attrs { + name: "quant_axis" + type: INT + } +} +extra { +} diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py index 97a94ef348a..26066be7dc7 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py @@ -491,8 +491,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): "x_num_col_dims": 2, "y_num_col_dims": 1, "enable_int8": True, - "X_scale": 1.0, - "weight_scale": [1.0], + "Input_scale": 1.0, }, { "axis": 2, "out_threshold": 1.0, @@ -504,8 +503,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): "x_num_col_dims": 2, "y_num_col_dims": 1, "enable_int8": True, - "X_scale": 1.0, - "weight_scale": [1.0], + "Input_scale": 1.0, }, { "axis": 2, "out_threshold": 1.0, @@ -517,8 +515,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): "x_num_col_dims": 2, "y_num_col_dims": 1, "enable_int8": True, - "X_scale": 1.0, - "weight_scale": [1.0], + "Input_scale": 1.0, }, { "axis": 2, "out_threshold": 1.0, -- GitLab