From f706d95dfe9301e18ee6575c3e58e7ba37d6e78a Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Tue, 16 Aug 2022 17:21:15 +0800 Subject: [PATCH] convert multihead to oss (#45019) * convert multihead to oss * fix:bug * fix:delete const cast * fix:don't support bias_qk * add vit pass * fix:convert bug and add preln_residual_bias * support length=-1 * add UT for convert * add no_bias_qk support for gpu_multihead_op * delete infer_shape depends on bias_qk * oss just can be used in T4 and A* * fix:change api for ROCM CI --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/graph_pattern_detector.cc | 148 ++++++++++ .../framework/ir/graph_pattern_detector.h | 52 ++++ .../ir/preln_residual_bias_fuse_pass.h | 2 +- .../framework/ir/vit_attention_fuse_pass.cc | 147 ++++++++++ .../framework/ir/vit_attention_fuse_pass.h | 41 +++ .../inference/api/paddle_pass_builder.cc | 1 + .../tensorrt/convert/multihead_matmul_op.cc | 232 +++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 40 +-- .../operators/fused/multihead_matmul_op.cc | 17 +- .../operators/fused/multihead_matmul_op.cu | 20 +- .../test_trt_convert_multihead_matmul.py | 266 ++++++++++++++++++ 12 files changed, 929 insertions(+), 38 deletions(-) create mode 100644 paddle/fluid/framework/ir/vit_attention_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/vit_attention_fuse_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 857d68d74fa..cb80d8453d2 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -117,6 +117,7 @@ pass_library(graph_viz_pass base) pass_library(lock_free_optimize_pass base DEPS string_helper) pass_library(fc_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference) +pass_library(vit_attention_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 5f8dcf9b7e5..271b9b9d029 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2334,6 +2334,154 @@ PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { return act_out; } +PDNode *patterns::VitAttention::operator()(PDNode *in) { + in->AsInput(); + std::unordered_set matmul_ops{"matmul", "matmul_v2"}; + + auto matmul0_op = + pattern->NewNode(matmul0_op_repr())->assert_is_ops(matmul_ops); + auto matmul0_in_y = pattern->NewNode(matmul0_in_y_repr()) + ->AsInput() + ->assert_is_ops_input(matmul_ops, "Y"); + auto matmul0_out = pattern->NewNode(matmul0_out_repr()) + ->assert_is_ops_output(matmul_ops, "Out") + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + + auto elementwise0_op = + pattern->NewNode(elementwise0_op_repr())->assert_is_op("elementwise_add"); + auto elementwise0_in_y = pattern->NewNode(elementwise0_in_y_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto elementwise0_out = pattern->NewNode(elementwise0_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("reshape2", "X") + ->AsIntermediate(); + + auto reshape1_op = + pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2"); + auto reshape1_out = pattern->NewNode(reshape1_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("transpose2", "X") + ->AsIntermediate(); + + 
auto transpose1_op = + pattern->NewNode(transpose1_op_repr())->assert_is_op("transpose2"); + auto transpose1_out = pattern->NewNode(transpose1_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("slice", "Input") + ->AsIntermediate(); + + auto slice1_op = pattern->NewNode(slice1_op_repr())->assert_is_op("slice"); + auto slice1_out = pattern->NewNode(slice1_out_repr()) + ->assert_is_op_output("slice", "Out") + ->assert_is_op_input("matmul_v2", "Y") + ->AsIntermediate(); + + auto slice2_op = pattern->NewNode(slice2_op_repr())->assert_is_op("slice"); + auto slice2_out = pattern->NewNode(slice2_out_repr()) + ->assert_is_op_output("slice", "Out") + ->assert_is_op_input("matmul_v2", "X") + ->AsIntermediate(); + + auto slice3_op = pattern->NewNode(slice3_op_repr())->assert_is_op("slice"); + auto slice3_out = pattern->NewNode(slice3_out_repr()) + ->assert_is_op_output("slice", "Out") + ->assert_is_op_input("transpose2", "X") + ->AsIntermediate(); + + auto transpose2_op = + pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2"); + auto transpose2_out = pattern->NewNode(transpose2_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("matmul_v2", "Y") + ->AsIntermediate(); + + auto matmul1_op = + pattern->NewNode(matmul1_op_repr())->assert_is_op("matmul_v2"); + auto matmul1_out = pattern->NewNode(matmul1_out_repr()) + ->assert_is_op_output("matmul_v2", "Out") + ->assert_is_op_input("scale", "X") + ->AsIntermediate(); + + auto scale1_op = pattern->NewNode(scale1_op_repr())->assert_is_op("scale"); + auto scale1_out = pattern->NewNode(scale1_out_repr()) + ->assert_is_op_output("scale", "Out") + ->assert_is_op_input("softmax", "X") + ->AsIntermediate(); + + auto softmax1_op = + pattern->NewNode(softmax1_op_repr())->assert_is_op("softmax"); + auto softmax1_out = pattern->NewNode(softmax1_out_repr()) + ->assert_is_op_output("softmax", "Out") + ->assert_is_op_input("matmul_v2", "X") + ->AsIntermediate(); + + auto matmul2_op = + pattern->NewNode(matmul2_op_repr())->assert_is_op("matmul_v2"); + auto matmul2_out = pattern->NewNode(matmul2_out_repr()) + ->assert_is_op_output("matmul_v2", "Out") + ->assert_is_op_input("transpose2", "X") + ->AsIntermediate(); + + auto transpose3_op = + pattern->NewNode(transpose3_op_repr())->assert_is_op("transpose2"); + auto transpose3_out = pattern->NewNode(transpose3_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("reshape2", "X") + ->AsIntermediate(); + + auto reshape2_op = + pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2"); + auto reshape2_out = pattern->NewNode(reshape2_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->AsOutput(); + + matmul0_op->LinksFrom({in, matmul0_in_y}); + matmul0_out->LinksFrom({matmul0_op}); + + elementwise0_op->LinksFrom({matmul0_out, elementwise0_in_y}); + elementwise0_out->LinksFrom({elementwise0_op}); + + reshape1_op->LinksFrom({elementwise0_out}); + reshape1_out->LinksFrom({reshape1_op}); + + transpose1_op->LinksFrom({reshape1_out}); + transpose1_out->LinksFrom({transpose1_op}); + + slice1_op->LinksFrom({transpose1_out}); + slice1_out->LinksFrom({slice1_op}); + + slice2_op->LinksFrom({transpose1_out}); + slice2_out->LinksFrom({slice2_op}); + + slice3_op->LinksFrom({transpose1_out}); + slice3_out->LinksFrom({slice3_op}); + + transpose2_op->LinksFrom({slice3_out}); + transpose2_out->LinksFrom({transpose2_op}); + + matmul1_op->LinksFrom({slice2_out, transpose2_out}); + matmul1_out->LinksFrom({matmul1_op}); + + 
scale1_op->LinksFrom({matmul1_out}); + scale1_out->LinksFrom({scale1_op}); + + softmax1_op->LinksFrom({scale1_out}); + softmax1_out->LinksFrom({softmax1_op}); + + matmul2_op->LinksFrom({slice1_out, softmax1_out}); + matmul2_out->LinksFrom({matmul2_op}); + + transpose3_op->LinksFrom({matmul2_out}); + transpose3_out->LinksFrom({transpose3_op}); + + reshape2_op->LinksFrom({transpose3_out}); + reshape2_out->LinksFrom({reshape2_op}); + + return reshape2_out; +} + PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) { auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_filter = pattern->NewNode(conv_filter_repr()) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 9e2eb21b7f4..507fb83af4e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1377,6 +1377,58 @@ struct PriorBox : public PatternBase { PATTERN_DECL_NODE(prior_box_variances); }; +// vit_attention +struct VitAttention : public PatternBase { + VitAttention(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "vit_attention") {} + + PDNode* operator()(PDNode* in); + + PATTERN_DECL_NODE(matmul0_op); + PATTERN_DECL_NODE(matmul0_in_y); + PATTERN_DECL_NODE(matmul0_out); + + PATTERN_DECL_NODE(elementwise0_op); + PATTERN_DECL_NODE(elementwise0_in_y); + PATTERN_DECL_NODE(elementwise0_out); + + PATTERN_DECL_NODE(reshape1_op); + PATTERN_DECL_NODE(reshape1_out); + + PATTERN_DECL_NODE(transpose1_op); + PATTERN_DECL_NODE(transpose1_out); + + PATTERN_DECL_NODE(slice1_op); + PATTERN_DECL_NODE(slice1_out); + + PATTERN_DECL_NODE(slice2_op); + PATTERN_DECL_NODE(slice2_out); + + PATTERN_DECL_NODE(slice3_op); + PATTERN_DECL_NODE(slice3_out); + + PATTERN_DECL_NODE(matmul2_op); + PATTERN_DECL_NODE(matmul2_out); + + PATTERN_DECL_NODE(matmul1_op); + PATTERN_DECL_NODE(matmul1_out); + + PATTERN_DECL_NODE(transpose2_op); + PATTERN_DECL_NODE(transpose2_out); + + PATTERN_DECL_NODE(scale1_op); + PATTERN_DECL_NODE(scale1_out); + + PATTERN_DECL_NODE(softmax1_op); + PATTERN_DECL_NODE(softmax1_out); + + PATTERN_DECL_NODE(transpose3_op); + PATTERN_DECL_NODE(transpose3_out); + + PATTERN_DECL_NODE(reshape2_op); + PATTERN_DECL_NODE(reshape2_out); +}; + // Conv + ElementwiseAdd + an activation // This pattern can further fuse the conv related ops after the conv+bn fusion. struct ConvElementwiseaddAct : public PatternBase { diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h index ba74d9f49f5..a22bc6d517a 100644 --- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h +++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.h @@ -45,7 +45,7 @@ class PrelnResidualBiasFusePass : public FusePassBase { .IsTensor() .End() .AddAttr("axis") - .IsIntIn({0, -1}) + .IsIntIn({0, -1, 2}) .End(); AddOpCompat(OpCompat("layer_norm")) diff --git a/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc b/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc new file mode 100644 index 00000000000..819708b447c --- /dev/null +++ b/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc @@ -0,0 +1,147 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/vit_attention_fuse_pass.h" + +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(matmul0_op); \ + GET_IR_NODE(matmul0_in_y); \ + GET_IR_NODE(matmul0_out); \ + GET_IR_NODE(elementwise0_op); \ + GET_IR_NODE(elementwise0_in_y); \ + GET_IR_NODE(elementwise0_out); \ + GET_IR_NODE(reshape1_op); \ + GET_IR_NODE(reshape1_out); \ + GET_IR_NODE(transpose1_op); \ + GET_IR_NODE(transpose1_out); \ + GET_IR_NODE(slice1_op); \ + GET_IR_NODE(slice1_out); \ + GET_IR_NODE(slice2_op); \ + GET_IR_NODE(slice2_out); \ + GET_IR_NODE(slice3_op); \ + GET_IR_NODE(slice3_out); \ + GET_IR_NODE(matmul1_op); \ + GET_IR_NODE(matmul1_out); \ + GET_IR_NODE(scale1_op); \ + GET_IR_NODE(scale1_out); \ + GET_IR_NODE(transpose2_op); \ + GET_IR_NODE(transpose2_out); \ + GET_IR_NODE(softmax1_op); \ + GET_IR_NODE(softmax1_out); \ + GET_IR_NODE(matmul2_op); \ + GET_IR_NODE(matmul2_out); \ + GET_IR_NODE(transpose3_op); \ + GET_IR_NODE(transpose3_out); \ + GET_IR_NODE(reshape2_op); \ + GET_IR_NODE(reshape2_out); + +namespace paddle { +namespace framework { +namespace ir { + +void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const { + GraphPatternDetector gpd; + const std::string pattern_name = "vit_attention_fuse"; + FusePassBase::Init(pattern_name, graph); + auto* scope = param_scope(); + + // pattern + std::unordered_set matmul_ops{"matmul", "matmul_v2"}; + PDNode* x = gpd.mutable_pattern() + ->NewNode("x") + ->assert_is_ops_input(matmul_ops, "X") + ->AsInput(); + patterns::VitAttention pattern(gpd.mutable_pattern(), pattern_name); + pattern(x); + + int fusion_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + // do something; + OpDesc desc(matmul0_op->Op()->Block()); + desc.SetType("multihead_matmul"); + desc.SetInput("Input", {subgraph.at(x)->Name()}); + // refactor W and Bias + auto* w_tensor = + scope->FindVar(matmul0_in_y->Name())->GetMutable(); + auto w_dims = + phi::make_ddim({w_tensor->dims()[0], 3, w_tensor->dims()[1] / 3}); + w_tensor->Resize(w_dims); + + auto* b_tensor = + scope->FindVar(elementwise0_in_y->Name())->GetMutable(); + auto bias_dims = phi::make_ddim({3, b_tensor->dims()[0] / 3}); + b_tensor->Resize(bias_dims); + + desc.SetInput("W", {matmul0_in_y->Name()}); + desc.SetInput("Bias", {elementwise0_in_y->Name()}); + std::vector shape = softmax1_out->Var()->GetShape(); + desc.SetOutput("Out", {reshape2_out->Name()}); + desc.SetAttr("head_number", static_cast(shape[1])); + float alpha = PADDLE_GET_CONST(float, scale1_op->Op()->GetAttr("scale")); + desc.SetAttr("alpha", alpha); + + // Create a new node for the fused op. + auto vit_attention_node = graph->CreateOpNode(&desc); + + // Link inputs and outputs. 
+ PADDLE_ENFORCE_NE( + subgraph.count(x), + 0, + platform::errors::NotFound("Detector did not find input x of conv2d.")); + + IR_NODE_LINK_TO(subgraph.at(x), vit_attention_node); // Input + IR_NODE_LINK_TO(matmul0_in_y, vit_attention_node); + IR_NODE_LINK_TO(elementwise0_in_y, vit_attention_node); + IR_NODE_LINK_TO(vit_attention_node, reshape2_out); // Output + + // Delete the unneeded nodes. + std::unordered_set marked_nodes( + {matmul0_op, matmul0_out, elementwise0_op, elementwise0_out, + reshape1_op, reshape1_out, transpose1_op, transpose1_out, + slice1_op, slice1_out, slice2_op, slice2_out, + slice3_op, slice3_out, matmul1_op, matmul1_out, + scale1_op, scale1_out, transpose2_op, transpose2_out, + softmax1_op, softmax1_out, matmul2_op, matmul2_out, + transpose3_op, transpose3_out, reshape2_op}); + + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(vit_attention_fuse_pass, + paddle::framework::ir::VitAttentionFusePass); +REGISTER_PASS_CAPABILITY(vit_attention_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .GE("reshape2", 0) + .GE("transpose2", 0) + .GE("slice", 0) + .GE("scale", 0) + .GE("softmax", 0) + .GE("matmul_v2", 0)); diff --git a/paddle/fluid/framework/ir/vit_attention_fuse_pass.h b/paddle/fluid/framework/ir/vit_attention_fuse_pass.h new file mode 100644 index 00000000000..731c5c6243d --- /dev/null +++ b/paddle/fluid/framework/ir/vit_attention_fuse_pass.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +// Fusing of vit_attention structure + +class Graph; + +class VitAttentionFusePass : public FusePassBase { + public: + virtual ~VitAttentionFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index b2d11a52c0a..3bfccf11307 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -101,6 +101,7 @@ const std::vector kTRTSubgraphPasses({ "delete_c_identity_op_pass", // "trt_multihead_matmul_fuse_pass_v2", // "trt_multihead_matmul_fuse_pass_v3", // + "vit_attention_fuse_pass", // "trt_skip_layernorm_fuse_pass", // "preln_skip_layernorm_fuse_pass", // "preln_residual_bias_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index ce678bd4915..a597a484f9e 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -29,6 +29,10 @@ class MultiheadMatMulOpConverter : public OpConverter { framework::OpDesc op_desc(op, nullptr); // Declare inputs auto* input = engine_->GetITensor(op_desc.Input("Input").front()); + auto input_dims = input->getDimensions(); + bool bias_qk_attr = + (op_desc.Inputs().find("BiasQK") == op_desc.Inputs().end()) ? false + : true; // fc weights and fc bias auto weight_name = op_desc.Input("W").front(); @@ -287,6 +291,234 @@ class MultiheadMatMulOpConverter : public OpConverter { plugin_inputs.data(), plugin_inputs.size(), *plugin); layer = plugin_layer; } + } + if (input_dims.d[1] <= 384 && !bias_qk_attr && + engine_->precision() != AnalysisConfig::Precision::kFloat32) { + /* + * input_dims.d[0]: batch(-1) + * input_dims.d[1]: length:256 + * input_dims.d[2]: hidden_size:768 + input + |[b,256,768] + | + shuffle weight bias + |[b,256,768,1,1] | | + |_____________________|_________| + | + fc + |[b,256,2304,1,1] + | + shuffle mask(fake) pos max_length + |[b*256,2304,1,1] | | | + | | | | + |_______________________|_________|________| + | + MHA + |[b*256,768] + | + shuffle + |[b, 256, 768] + | + out + */ + + nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + + /*** transpose the weight and bias ***/ + int head_size = hidden_out / head_number; + // [3, head_number, head_size, hidden_in] -> [head_number, 3, + // head_size, hidden_in] + auto transpose_weight_v2 = [](const float* src, + float* dst, + int three, + int head_number, + int head_size, + int hidden_in) { + const int HH = head_size * hidden_in; + for (int i = 0; i < three; ++i) { + for (int n = 0; n < head_number; ++n) { + for (int hh = 0; hh < HH; ++hh) { + dst[n * three * HH + i * HH + hh] = + src[i * head_number * HH + n * HH + hh]; + } + } + } + }; + // [3, head_number, head_size] -> [head_number, 3, head_size] + auto transpose_bias_v2 = + [](const float* src, float* dst, int N, int H) { + for (int i = 0; i < 3; ++i) { + for (int 
n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } + } + }; + memcpy(weight_data_tmp.data(), + weight_data, + weight_t->numel() * sizeof(float)); + transpose_weight_v2(weight_data_tmp.data(), + weight_data, + three, + head_number, + head_size, + hidden_in); + + std::vector bias_data_tmp; + bias_data_tmp.reserve(bias_t->numel()); + memcpy( + bias_data_tmp.data(), bias_data, bias_t->numel() * sizeof(float)); + transpose_bias_v2( + bias_data_tmp.data(), bias_data, head_number, head_size); + + // add shuffle for FullyConnected layer + std::vector reshape_before_fc_shape_tensor; + nvinfer1::ITensor* input_shape_tensor = Shape(input); + for (int i = 0; i < 5; i++) { + reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1)); + } + for (int i = 0; i < 3; i++) { + reshape_before_fc_shape_tensor[i] = + GetEleTensorOfShape(input_shape_tensor, i); + } + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + reshape_before_fc_layer->setInput( + 1, *Concat(reshape_before_fc_shape_tensor)); + reshape_before_fc_layer->setName( + ("shuffle_before_fc_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + // add fc layer + nvinfer1::ILayer* fc_layer = nullptr; + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *reshape_before_fc_layer->getOutput(0), + n, + weight, + bias); + + // add shuffle for CustomQKVToContextPluginDynamic layer + auto* reshape_after_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0)); + std::vector mha_input_tensor_shape; + mha_input_tensor_shape.push_back(Add1DConstantLayer(-1)); + mha_input_tensor_shape.push_back( + Add1DConstantLayer(hidden_out * 3)); // Q,K,V + mha_input_tensor_shape.push_back(Add1DConstantLayer(1)); + mha_input_tensor_shape.push_back(Add1DConstantLayer(1)); + reshape_after_fc_layer->setInput(1, *Concat(mha_input_tensor_shape)); + reshape_after_fc_layer->setName( + ("shuffle_after_fc_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + // add mha_plugin + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "2"); + assert(creator != nullptr); + // set the attributes of mha_plugin + int type = static_cast(nvinfer1::DataType::kHALF); + int var_seqlen = 1; + bool has_mask = true; + std::vector fields{ + {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, + {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, + {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, + {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}}; + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast( + malloc(sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); + // set inputs + std::vector plugin_inputs; + // input_0 for plugin + plugin_inputs.emplace_back(reshape_after_fc_layer->getOutput(0)); + // input_1(fake) for plugin + std::vector mask = {1}; + nvinfer1::ITensor* mask_tensor = Add1DConstantLayer(mask); + plugin_inputs.emplace_back(mask_tensor); + // input_2 for plugin + std::vector pos_id = {0}; + int max_batch = 500; + for (int i = 1; i < max_batch; i++) { + pos_id.push_back(i); + } + 
nvinfer1::ITensor* fake_pos_id_tensor = Add1DConstantLayer(pos_id); + nvinfer1::ITensor* length_tensor = + GetEleTensorOfShape(input_shape_tensor, 1); + auto pos_id_layer = + TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *fake_pos_id_tensor, + *length_tensor, + nvinfer1::ElementWiseOperation::kPROD); + // size = batch + 1; + nvinfer1::ITensor* batch_tensor = + GetEleTensorOfShape(input_shape_tensor, 0); + std::vector const_data = {1}; + nvinfer1::ITensor* const_tensor = Add1DConstantLayer(const_data); + auto size_layer = + TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *batch_tensor, + *const_tensor, + nvinfer1::ElementWiseOperation::kSUM); + // get size(batch + 1) data from pos_id_tensor + nvinfer1::Dims start; + nvinfer1::Dims stride; + nvinfer1::Dims size; + + start.nbDims = 1; + stride.nbDims = 1; + size.nbDims = 1; + + start.d[0] = 0; + stride.d[0] = 1; + size.d[0] = 1; + + auto* slice_pos_layer = TRT_ENGINE_ADD_LAYER( + engine_, Slice, *pos_id_layer->getOutput(0), start, size, stride); + slice_pos_layer->setInput(2, *size_layer->getOutput(0)); + plugin_inputs.emplace_back(slice_pos_layer->getOutput(0)); + + // input_3 for plugin + std::vector data(500, 1); + nvinfer1::ITensor* fake_max_seqlen_tensor = Add1DConstantLayer(data); + auto* slice_max_layer = TRT_ENGINE_ADD_LAYER( + engine_, Slice, *fake_max_seqlen_tensor, start, size, stride); + slice_max_layer->setInput(2, *length_tensor); + plugin_inputs.emplace_back(slice_max_layer->getOutput(0)); + // plugin_layer + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + + // add shuffle + auto* reshape_after_mha_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0)); + std::vector reshape_tensor; + reshape_tensor.push_back(batch_tensor); + reshape_tensor.push_back(length_tensor); + reshape_tensor.push_back(Add1DConstantLayer(-1)); + reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor)); + reshape_after_mha_layer->setName( + ("shuffle_last_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + // return + layer = reshape_after_mha_layer; } else { PADDLE_ENFORCE_EQ( input->getDimensions().nbDims, diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6ed6ba57075..910d0393167 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1820,24 +1820,32 @@ bool OpTeller::Tell(const framework::ir::Node* node, const auto input_shape = input_desc->GetShape(); const auto head_number = PADDLE_GET_CONST(int, desc.GetAttr("head_number")); - - auto* biasqk_desc = block->FindVar(desc.Input("BiasQK").front()); - const auto biasqk_shape = biasqk_desc->GetShape(); - // The BiasQK's shape requires to be - // [batch, 1, 1, length] or [batch, head, length, length]. - bool has_same_shape = head_number == biasqk_shape[1] && - input_shape[1] == biasqk_shape[2] && - input_shape[1] == biasqk_shape[3]; - bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 && + auto inputs = desc.Inputs(); + bool has_bias_qk = (inputs.find("BiasQK") == inputs.end()) ? false : true; + if (has_bias_qk) { + auto* biasqk_desc = block->FindVar(desc.Input("BiasQK").front()); + const auto biasqk_shape = biasqk_desc->GetShape(); + // The BiasQK's shape requires to be + // [batch, 1, 1, length] or [batch, head, length, length]. 
+ bool has_same_shape = head_number == biasqk_shape[1] && + input_shape[1] == biasqk_shape[2] && input_shape[1] == biasqk_shape[3]; - if (!(has_same_shape || is_broadcastable)) { - VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0] - << ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0] - << ", " << head_number << ", " << input_shape[1] << ", " - << input_shape[1] << "] but [" << biasqk_shape[0] << ", " - << biasqk_shape[1] << ", " << biasqk_shape[2] << ", " - << biasqk_shape[3] << "]."; + bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 && + input_shape[1] == biasqk_shape[3]; + if (!(has_same_shape || is_broadcastable)) { + VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0] + << ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0] + << ", " << head_number << ", " << input_shape[1] << ", " + << input_shape[1] << "] but [" << biasqk_shape[0] << ", " + << biasqk_shape[1] << ", " << biasqk_shape[2] << ", " + << biasqk_shape[3] << "]."; + return false; + } + } else { +#if !IS_TRT_VERSION_GE(8000) + VLOG(3) << "The version of TRT must be greater than 8000"; return false; +#endif } } diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cc b/paddle/fluid/operators/fused/multihead_matmul_op.cc index 6982fadc75b..d263e885263 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cc +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cc @@ -40,11 +40,6 @@ class MultiHeadMatMulV2Op : public framework::OperatorWithKernel { true, platform::errors::InvalidArgument( "Input(Bias) of MultiHeadMatMul should not be null.")); - PADDLE_ENFORCE_EQ( - context->HasInput("BiasQK"), - true, - platform::errors::InvalidArgument( - "Input(BiasQK) of MultiHeadMatMul should not be null.")); PADDLE_ENFORCE_EQ( context->HasOutput("Out"), true, @@ -69,15 +64,6 @@ class MultiHeadMatMulV2Op : public framework::OperatorWithKernel { "%d-D tensor now.", dim_bias_q.size())); - auto dim_bias_qk = context->GetInputDim("BiasQK"); - PADDLE_ENFORCE_GT( - dim_bias_qk.size(), - 3, - platform::errors::InvalidArgument( - "Multihead input bias qk should be at least 4-D tensor, " - "but it's %d-D tensor now.", - dim_bias_qk.size())); - auto dim_input = context->GetInputDim("Input"); context->SetOutputDim("Out", dim_input); context->ShareLoD("Input", /*->*/ "Out"); @@ -90,7 +76,8 @@ class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Input", "The input of MultiHeadMatMul op"); AddInput("W", "The weight input of MultiHeadMatMul op"); AddInput("Bias", "The bias input of MultiHeadMatMul op"); - AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op"); + AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op") + .AsDispensable(); AddOutput("Out", "The output of MultiHeadMatMul op"); AddAttr("transpose_Q", R"DOC(If true, use the transpose of `Q`. 
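
Note: with BiasQK marked dispensable above, op_teller.cc only admits the no-mask variant of multihead_matmul when TensorRT is at least 8.0. The unit test added later in this patch encodes the same gate; a standalone check might look like the sketch below, reusing the test's version arithmetic (the helper name is illustrative, not part of the patch).

# Sketch: same TRT >= 8.0 check as generate_trt_nodes_num in the new test.
# The helper name is illustrative only.
import paddle.inference as paddle_infer

def trt_supports_no_bias_qk():
    major, minor, patch = paddle_infer.get_trt_compile_version()
    return major * 1000 + minor * 100 + patch * 10 >= 8000
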
diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index f2d010e16a2..16ab0d916d9 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -264,15 +264,12 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { auto *input = context.Input("Input"); auto *w = context.Input("W"); auto *bias = context.Input("Bias"); - auto &bias_qk = GET_DATA_SAFELY(context.Input("BiasQK"), - "Input", - "BiasQK", - "MultiHeadMatMulV2"); + auto *bias_qk = context.Input("BiasQK"); auto *input_d = input->data(); auto *w_d = w->data(); auto *bias_d = bias->data(); - auto *bias_qk_d = bias_qk.template data(); + auto *bias_qk_d = bias_qk ? bias_qk->data() : nullptr; T scale = static_cast(context.Attr("alpha")); int head_number = context.Attr("head_number"); @@ -288,7 +285,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { int hidden = input_dims[2]; Tensor temp_bias_tensor; // if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted - if (bias_qk.numel() == (batch * seq_len)) { + if (bias_qk && bias_qk->numel() == (batch * seq_len)) { temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len}); auto *temp_qk_bias = temp_bias_tensor.mutable_data(context.GetPlace()); int grid = batch * head_number * seq_len; @@ -297,6 +294,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { bias_qk_d, temp_qk_bias, seq_len, head_number); bias_qk_d = static_cast(temp_qk_bias); } + if (!bias_qk) { + int size = batch * head_number * seq_len * seq_len; + temp_bias_tensor.Resize({size}); + auto *temp_qk_bias = temp_bias_tensor.mutable_data(context.GetPlace()); +#ifdef PADDLE_WITH_HIP + hipMemset(temp_qk_bias, 0, sizeof(float) * size); +#else + cudaMemset(temp_qk_bias, 0, sizeof(float) * size); +#endif + bias_qk_d = static_cast(temp_qk_bias); + } int all_head_size = w_dims[2]; int head_size = all_head_size / head_number; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py index 9fd60e5f3fe..cf8611e4d8b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py @@ -848,5 +848,271 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): yield program_config +class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + + def generate_input1(): + return np.zeros((1, 256, 768), dtype=np.float32) + + def generate_weight1(): + return np.random.rand(768, 2304).astype(np.float32) + + def generate_weight2(): + return np.random.rand(2304).astype(np.float32) + + ops_config = [{ + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data1"], + "Y": ["matmul1_weight"] + }, + "op_outputs": { + "Out": ["matmul1_output"] + }, + "op_attrs": { + "trans_x": False, + "trans_y": False + } + }, { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["matmul1_output"], + "Y": ["elementwise_add1_weight"] + }, + "op_outputs": { + "Out": ["elementwise_add1_output"] + }, + "op_attrs": { + "Scale_out": 1.0, + "Scale_x": 1.0, + "Scale_y": 1.0, + "axis": 2 + } + }, { + "op_type": "reshape2", + "op_inputs": { + "X": ["elementwise_add1_output"], + }, + "op_outputs": 
{ + "Out": ["reshape1_output"], + "XShape": ["reshape1_output_xshape"] + }, + "op_attrs": { + "shape": [-1, 256, 3, 12, 64] + } + }, { + "op_type": "transpose2", + "op_inputs": { + "X": ["reshape1_output"] + }, + "op_outputs": { + "Out": ["transpose1_output"], + "XShape": ["transpose1_output_xshape"] + }, + "op_attrs": { + "axis": [2, 0, 3, 1, 4], + "data_format": "AnyLayout" + } + }, { + "op_type": "slice", + "op_inputs": { + "Input": ["transpose1_output"], + }, + "op_outputs": { + "Out": ["slice1_output"] + }, + "op_attrs": { + "axes": [0], + "starts": [0], + "ends": [1], + "decrease_axis": [0], + "infer_flags": [1] + } + }, { + "op_type": "slice", + "op_inputs": { + "Input": ["transpose1_output"], + }, + "op_outputs": { + "Out": ["slice2_output"] + }, + "op_attrs": { + "axes": [0], + "starts": [1], + "ends": [2], + "decrease_axis": [0], + "infer_flags": [1] + } + }, { + "op_type": "slice", + "op_inputs": { + "Input": ["transpose1_output"], + }, + "op_outputs": { + "Out": ["slice3_output"] + }, + "op_attrs": { + "axes": [0], + "starts": [2], + "ends": [3], + "decrease_axis": [0], + "infer_flags": [1] + } + }, { + "op_type": "transpose2", + "op_inputs": { + "X": ["slice2_output"] + }, + "op_outputs": { + "Out": ["transpose2_output"], + }, + "op_attrs": { + "axis": [0, 1, 3, 2], + "data_format": "AnyLayout" + } + }, { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["slice1_output"], + "Y": ["transpose2_output"] + }, + "op_outputs": { + "Out": ["matmul2_output"] + }, + "op_attrs": { + "trans_x": False, + "trans_y": False + } + }, { + "op_type": "scale", + "op_inputs": { + "X": ["matmul2_output"], + }, + "op_outputs": { + "Out": ["scale_output"] + }, + "op_attrs": { + "scale": 0.125, + "bias": 0.0, + "bias_after_scale": True + } + }, { + "op_type": "softmax", + "op_inputs": { + "X": ["scale_output"] + }, + "op_outputs": { + "Out": ["softmax_output"] + }, + "op_attrs": { + "axis": -1, + "data_format": "AnyLayout" + } + }, { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["softmax_output"], + "Y": ["slice3_output"] + }, + "op_outputs": { + "Out": ["matmul3_output"] + }, + "op_attrs": { + "trans_x": False, + "trans_y": False + } + }, { + "op_type": "transpose2", + "op_inputs": { + "X": ["matmul3_output"] + }, + "op_outputs": { + "Out": ["transpose3_output"], + "XShape": ["transpose3_output_xshape"] + }, + "op_attrs": { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout" + } + }, { + "op_type": "reshape2", + "op_inputs": { + "X": ["transpose3_output"] + }, + "op_outputs": { + "Out": ["reshape2_output"], + "XShape": ["reshape2_output_xshape"] + }, + "op_attrs": { + "shape": [-1, 256, 768] + } + }] + + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "matmul1_weight": + TensorConfig(data_gen=partial(generate_weight1)), + "elementwise_add1_weight": + TensorConfig(data_gen=partial(generate_weight2)) + }, + inputs={ + "input_data1": TensorConfig(data_gen=partial(generate_input1)) + }, + outputs=["reshape2_output"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + # The last dim of input1 and input2 should be static. 
+ self.dynamic_shape.min_input_shape = { + "input_data1": [1, 8, 768], + } + self.dynamic_shape.max_input_shape = { + "input_data1": [16, 512, 768], + } + self.dynamic_shape.opt_input_shape = { + "input_data1": [1, 197, 768], + } + + def clear_dynamic_shape(): + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + def generate_trt_nodes_num(): + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8000: + return 0, 3 + return 1, 2 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.workspace_size = 2013265920 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num(), (1e-3, + 1e-3) + + def add_skip_trt_case(self): + pass + + def test(self): + self.add_skip_trt_case() + self.run_test() + + if __name__ == "__main__": unittest.main() -- GitLab
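
Note: the oss path added by this patch engages only when the fused multihead_matmul op has no BiasQK input, input_dims.d[1] is at most 384, the engine precision is not FP32, and (per op_teller.cc) TensorRT is 8.0 or newer. A minimal sketch of a Paddle Inference configuration that satisfies those conditions follows; the model file names and the input name/shapes are placeholders taken from the new unit test, not requirements of the patch.

# Sketch: configure the TensorRT subgraph engine so that vit_attention_fuse_pass
# and the oss multihead_matmul converter can apply. Paths and shapes are
# placeholder assumptions mirroring the unit test above.
import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(1000, 0)  # initial GPU memory pool (MB), device id
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Half,  # oss path needs non-FP32
    use_static=False,
    use_calib_mode=False)
# Dynamic shapes mirror the new unit test: hidden size fixed at 768,
# batch and sequence length variable.
config.set_trt_dynamic_shape_info(
    {"input_data1": [1, 8, 768]},     # min
    {"input_data1": [16, 512, 768]},  # max
    {"input_data1": [1, 197, 768]})   # opt
predictor = paddle_infer.create_predictor(config)
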