diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt old mode 100755 new mode 100644 index 5e46fd92bf56facc21bb2c1ebb98b670d94a0a32..bafdd26cf59b3ec36552c01b3868340dfe8b3d5a --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -105,6 +105,7 @@ pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(skip_layernorm_fuse_pass base) pass_library(multihead_matmul_fuse_pass inference) +pass_library(multihead_matmul_roformer_fuse_pass inference) pass_library(fused_multi_transformer_encoder_pass inference) pass_library(fused_multi_transformer_decoder_pass inference) pass_library(fuse_multi_transformer_layer_pass inference) diff --git a/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..dcb7d1efa927073c135eb40fab0eb1823f1ed81a --- /dev/null +++ b/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.cc @@ -0,0 +1,830 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +static void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) { + if (op->IsOp() && op->Op()) { + new_var->inputs.push_back(op); + for (size_t i = 0; i < op->outputs.size(); ++i) { + if (op->outputs[i] == old_var) { + op->outputs[i] = new_var; + op->Op()->RenameOutput(old_var->Name(), new_var->Name()); + } + } + } +} + +PDNode* MultiHeadMatmulRoformerPattern::operator()() { + std::unordered_set matmul_ops{"matmul", "matmul_v2"}; + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_ops_input(matmul_ops); + + auto* input_cos = pattern->NewNode(input_cos_repr()); + input_cos->assert_is_op_input("elementwise_mul", "Y"); + auto* input_sin = pattern->NewNode(input_sin_repr()); + input_sin->assert_is_op_input("elementwise_mul", "Y"); + // First path with scale + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(matmul_ops); + auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) + ->AsInput() + ->assert_is_ops_input(matmul_ops, "Y"); + auto* mul0_out_var = + pattern->NewNode(mul0_out_repr())->assert_is_ops_output(matmul_ops); + + decltype(mul0) eltadd0; + decltype(mul0) eltadd0_b_var; + decltype(mul0) eltadd0_out_var; + + mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + eltadd0_b_var = pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + 
->assert_is_op_output("elementwise_add"); + eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + + auto* reshape2_0_out_var = + pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2"); + reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_0_out_var->AsIntermediate() + ->assert_is_op_input("elementwise_mul", "X") + ->assert_is_op_input("split", "X"); + + auto* eltmul_cos_q = + pattern->NewNode(eltmul_cos_q_repr())->assert_is_op("elementwise_mul"); + auto* eltmul_cos_q_out_var = pattern->NewNode(eltmul_cos_q_out_repr()) + ->assert_is_op_output("elementwise_mul"); + eltmul_cos_q_out_var->AsIntermediate()->assert_is_op_input("elementwise_add", + "X"); + + auto* split_q = pattern->NewNode(split_q_repr())->assert_is_op("split"); + auto* split_q_out_var = + pattern->NewNode(split_q_out_repr())->assert_is_op_output("split"); + split_q_out_var->AsIntermediate()->assert_is_op_input("concat", "X"); + auto* concat_q = pattern->NewNode(concat_q_repr())->assert_is_op("concat"); + auto* concat_q_out_var = + pattern->NewNode(concat_q_out_repr())->assert_is_op_output("concat"); + concat_q_out_var->AsIntermediate()->assert_is_op_input("elementwise_mul", + "X"); + + auto* eltmul_sin_q = + pattern->NewNode(eltmul_sin_q_repr())->assert_is_op("elementwise_mul"); + auto* eltmul_sin_q_out_var = pattern->NewNode(eltmul_sin_q_out_repr()) + ->assert_is_op_output("elementwise_mul"); + eltmul_sin_q_out_var->AsIntermediate()->assert_is_op_input("elementwise_add", + "Y"); + + auto* eltadd_q = + pattern->NewNode(eltadd_q_repr())->assert_is_op("elementwise_add"); + auto* eltadd_q_out_var = pattern->NewNode(eltadd_q_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd_q_out_var->AsIntermediate()->assert_is_op_input("scale"); + + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = + pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); + softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); + + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + 
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2"); + reshape2_qkv_out_var->assert_is_ops_input(matmul_ops); + + auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + auto* scale_out_var = + pattern->NewNode(scale_out_repr())->assert_is_op_output("scale"); + scale_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); + + // Second path to matmul + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(matmul_ops); + auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) + ->AsInput() + ->assert_is_ops_input(matmul_ops, "Y"); + auto* mul1_out_var = + pattern->NewNode(mul1_out_repr())->assert_is_ops_output(matmul_ops); + + decltype(mul1) eltadd1; + decltype(mul1) eltadd1_b_var; + decltype(mul1) eltadd1_out_var; + + mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + + auto* reshape2_1_out_var = + pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2"); + reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_1_out_var->AsIntermediate() + ->assert_is_op_input("elementwise_mul", "X") + ->assert_is_op_input("split", "X"); // link to matmul qk + + auto* eltmul_cos_k = + pattern->NewNode(eltmul_cos_k_repr())->assert_is_op("elementwise_mul"); + auto* eltmul_cos_k_out_var = pattern->NewNode(eltmul_cos_k_out_repr()) + ->assert_is_op_output("elementwise_mul"); + eltmul_cos_k_out_var->AsIntermediate()->assert_is_op_input("elementwise_add", + "X"); + + auto* split_k = pattern->NewNode(split_k_repr())->assert_is_op("split"); + auto* split_k_out_var = + pattern->NewNode(split_k_out_repr())->assert_is_op_output("split"); + split_k_out_var->AsIntermediate()->assert_is_op_input("concat", "X"); + auto* concat_k = pattern->NewNode(concat_k_repr())->assert_is_op("concat"); + auto* concat_k_out_var = + pattern->NewNode(concat_k_out_repr())->assert_is_op_output("concat"); + concat_k_out_var->AsIntermediate()->assert_is_op_input("elementwise_mul", + "X"); + + auto* eltmul_sin_k = + pattern->NewNode(eltmul_sin_k_repr())->assert_is_op("elementwise_mul"); + auto* eltmul_sin_k_out_var = pattern->NewNode(eltmul_sin_k_out_repr()) + ->assert_is_op_output("elementwise_mul"); + eltmul_sin_k_out_var->AsIntermediate()->assert_is_op_input("elementwise_add", + "Y"); + + auto* eltadd_k = + pattern->NewNode(eltadd_k_repr())->assert_is_op("elementwise_add"); + auto* eltadd_k_out_var = pattern->NewNode(eltadd_k_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd_k_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); + + // Third path to matmul + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(matmul_ops); + auto* mul2_w_var = 
pattern->NewNode(mul2_w_repr()) + ->AsInput() + ->assert_is_ops_input(matmul_ops, "Y"); + auto* mul2_out_var = + pattern->NewNode(mul2_out_repr())->assert_is_ops_output(matmul_ops); + + decltype(mul2) eltadd2; + decltype(mul2) eltadd2_b_var; + decltype(mul2) eltadd2_out_var; + + mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); + eltadd2_b_var = pattern->NewNode(eltadd2_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd2_out_var = pattern->NewNode(eltadd2_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + + auto* reshape2_2_out_var = + pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2"); + reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_2_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops); // link to matmul qkv + + // Q path + mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); + eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var}); + + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + split_q->LinksFrom({transpose2_0_out_var}).LinksTo({split_q_out_var}); + concat_q->LinksFrom({split_q_out_var}).LinksTo({concat_q_out_var}); + eltmul_sin_q->LinksFrom({concat_q_out_var, input_sin}) + .LinksTo({eltmul_sin_q_out_var}); + eltmul_cos_q->LinksFrom({transpose2_0_out_var, input_cos}) + .LinksTo({eltmul_cos_q_out_var}); + eltadd_q->LinksFrom({eltmul_cos_q_out_var, eltmul_sin_q_out_var}) + .LinksTo({eltadd_q_out_var}); + scale->LinksFrom({eltadd_q_out_var}).LinksTo({scale_out_var}); + // K path + mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var}); + eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var}); + reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + split_k->LinksFrom({transpose2_1_out_var}).LinksTo({split_k_out_var}); + concat_k->LinksFrom({split_k_out_var}).LinksTo({concat_k_out_var}); + eltmul_sin_k->LinksFrom({concat_k_out_var, input_sin}) + .LinksTo({eltmul_sin_k_out_var}); + eltmul_cos_k->LinksFrom({transpose2_1_out_var, input_cos}) + .LinksTo({eltmul_cos_k_out_var}); + eltadd_k->LinksFrom({eltmul_cos_k_out_var, eltmul_sin_k_out_var}) + .LinksTo({eltadd_k_out_var}); + + // compute q*k + matmul_qk->LinksFrom({scale_out_var, eltadd_k_out_var}) + .LinksTo({matmul_qk_out_var}); + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + // V path + mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var}); + eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var}); + reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + // compute q*k*v + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) + 
.LinksTo({matmul_qkv_out_var});
+  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
+      .LinksTo({transpose2_qkv_out_var});
+  reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
+      .LinksTo({reshape2_qkv_out_var});
+
+  return transpose2_2_out_var;
+}
+}  // namespace patterns
+
+MultiHeadMatmulRoformerFusePass::MultiHeadMatmulRoformerFusePass() {
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")  // the shape should be (B, S, N*H)
+      .IsTensor()
+      .End()
+      .AddInput("Y")  // the shape should be (N*H, N*H)
+      .IsTensor()
+      .End()
+      .AddOutput("Out")  // the shape should be (B, S, N*H)
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumEQ(2)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      // in bias, shape is (B, S, N*H),
+      // in biasqk, shape is (B, H, S, S)
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      // in bias, shape is (N*H)
+      // in biasqk, shape is (B, H, S, S)
+      .IsTensor()
+      .End()
+      // in bias, shape is (B, S, N*H)
+      // in biasqk, shape is (B, H, S, S)
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // in bias, it is equal to 2
+      // in biasqk, it is equal to -1 or 0
+      .AddAttr("axis")
+      .IsIntIn({2, -1, 0})
+      .End();
+
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Shape")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ShapeTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("shape")  // -->(B, S, H, N) <--(B, S, N*H)
+      .IsType<std::vector<int>>()
+      .End();
+
+  // -->: (B, S, H, N) -> (B, H, S, N)
+  // <--: (B, H, S, N) -> (B, S, H, N)
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("axis")  // {0, 2, 1, 3}
+      .IsType<std::vector<int>>()
+      .End();
+
+  // QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
+  // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsType<float>()  // QK(anyvalue, will copy to new op) QKV(1.0)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")  // QK(true) QKV(false)
+      .IsType<bool>()
+      .End();
+
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("trans_y")  // QK(true) QKV(false)
+      .IsType<bool>()
+      .End();
+
+  AddOpCompat(OpCompat("softmax"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({-1, 3})  // shape is (B, H, S, S), so axis is -1 or 3
+      .End();
+}
+
+int MultiHeadMatmulRoformerFusePass::BuildFusion(Graph* graph,
+                                                 const std::string& name_scope,
+                                                 Scope* scope) const {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  // Create pattern.
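Before the detector is built, it may help to trace the tensor shapes behind the OpCompat comments above. Using B = batch, S = seq_len, N = head_number (read from reshape2's shape attribute, index 2) and H = head_size, the matched subgraph flows as follows (a sketch; note that some code comments above use H for the head count instead):

```
mul:        (B, S, N*H) x (N*H, N*H)       -> (B, S, N*H)
reshape2:   (B, S, N*H)                    -> (B, S, N, H)
transpose2: (B, S, N, H), axis {0,2,1,3}   -> (B, N, S, H)
RoPE:       cos/sin elementwise ops, shape preserved; Q also scaled
matmul QK:  (B, N, S, H) x (B, N, S, H)^T  -> (B, N, S, S)
softmax(axis = -1 or 3), then matmul QKV   -> (B, N, S, H)
transpose2 + reshape2                      -> (B, S, N*H)
```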
+  patterns::MultiHeadMatmulRoformerPattern multihead_pattern(pattern,
+                                                             name_scope);
+
+  multihead_pattern();
+  // Create New OpDesc
+  auto fuse_creater = [&](Node* input0,
+                          Node* input_cos,
+                          Node* input_sin,
+                          Node* mul0,
+                          Node* mul1,
+                          Node* mul2,
+                          Node* mul0_out,
+                          Node* mul1_out,
+                          Node* mul2_out,
+                          Node* mul0_w,
+                          Node* mul1_w,
+                          Node* mul2_w,
+                          Node* eltadd0_b,
+                          Node* eltadd1_b,
+                          Node* eltadd2_b,
+                          Node* eltadd_qk_b,
+                          Node* reshape2,
+                          Node* reshape2_qkv_out,
+                          Node* scale,
+                          Node* scale_out,
+                          Node* matmul_qk) {
+    auto scale_attr = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale"));
+
+    // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
+    // bias (B * S * 3 * N * H) + bias (3 * N * H)
+    // Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H)
+    auto* wq_tensor =
+        scope->FindVar(mul0_w->Name())->GetMutable<phi::DenseTensor>();
+    auto* wk_tensor =
+        scope->FindVar(mul1_w->Name())->GetMutable<phi::DenseTensor>();
+    auto* wv_tensor =
+        scope->FindVar(mul2_w->Name())->GetMutable<phi::DenseTensor>();
+
+    auto* bq_tensor =
+        scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
+    auto* bk_tensor =
+        scope->FindVar(eltadd1_b->Name())->GetMutable<phi::DenseTensor>();
+    auto* bv_tensor =
+        scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
+
+    auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());
+
+    auto combined_w_dims =
+        phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
+    auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});
+
+    // reuse the mul0_w and eltadd0_b nodes for the combined nodes.
+    auto* combined_w_desc = mul0_w->Var();
+    combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
+    combined_w_desc->SetPersistable(true);
+
+    auto* combined_bias_desc = eltadd0_b->Var();
+    combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
+    combined_bias_desc->SetPersistable(true);
+
+    phi::DenseTensor tmp_combined_w_tensor;
+    tmp_combined_w_tensor.Resize(combined_w_dims);
+    auto* tmp_combined_w_data =
+        tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
+
+    std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
+    int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
+    // Combine the three fc weights together.
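For intuition about the layout produced by the triple loop that follows: with a toy hidden_in = 2 and hidden_out = 2 (values invented for illustration), the rows of Wq, Wk and Wv are interleaved so each input row carries its q, k and v weights contiguously, yielding the (hidden_in, 3, hidden_out) combined tensor:

```cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> wq{0, 1, 2, 3}, wk{10, 11, 12, 13}, wv{20, 21, 22, 23};
  const float* w_vec[3] = {wq.data(), wk.data(), wv.data()};
  const int dims_h = 2, dims_w = 2;  // hidden_in, hidden_out
  std::vector<float> combined(dims_h * 3 * dims_w);
  // Same index arithmetic as the pass's packing loop.
  for (int i = 0; i < dims_h; i++)
    for (int j = 0; j < 3; j++)
      for (int k = 0; k < dims_w; k++)
        combined[i * (3 * dims_w) + j * dims_w + k] = w_vec[j][i * dims_w + k];
  // Prints: 0 1 10 11 20 21 2 3 12 13 22 23
  for (float v : combined) std::printf("%g ", v);
  return 0;
}
```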
+ for (int i = 0; i < dims_h; i++) { + for (int j = 0; j < 3; j++) { + for (int k = 0; k < dims_w; k++) { + int out_index = i * (3 * dims_w) + j * dims_w + k; + int in_index = i * dims_w + k; + tmp_combined_w_data[out_index] = w_vec[j][in_index]; + } + } + } + + wq_tensor->Resize(combined_w_dims); + auto* new_combined_w_data = + wq_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_combined_w_data, + tmp_combined_w_data, + sizeof(float) * wq_tensor->numel()); + + scope->EraseVars({mul1_w->Name(), mul2_w->Name()}); + + phi::DenseTensor tmp_combined_bias_tensor; + tmp_combined_bias_tensor.Resize(combined_bias_dims); + auto* tmp_combined_bias_data = + tmp_combined_bias_tensor.mutable_data(platform::CPUPlace()); + + size_t bias_size = bq_tensor->numel(); + memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size); + memcpy( + tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + 2 * bias_size, + bv_data, + sizeof(float) * bias_size); + + bq_tensor->Resize(combined_bias_dims); + auto* new_combined_bias_data = + bq_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_combined_bias_data, + tmp_combined_bias_data, + sizeof(float) * bq_tensor->numel()); + + scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); + + auto reshape_desc = reshape2->Op(); + int head_number = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(2); + + OpDesc multihead_op_desc(mul0->Op()->Block()); + multihead_op_desc.SetType("multihead_matmul_roformer"); + + multihead_op_desc.SetInput("Input", {input0->Name()}); + multihead_op_desc.SetInput("Input_cos", {input_cos->Name()}); + multihead_op_desc.SetInput("Input_sin", {input_sin->Name()}); + multihead_op_desc.SetInput("W", {mul0_w->Name()}); + multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()}); + multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()}); + + multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); + multihead_op_desc.SetAttr("alpha", scale_attr); + multihead_op_desc.SetAttr("head_number", head_number); + + auto* multihead = graph->CreateOpNode(&multihead_op_desc); + + IR_NODE_LINK_TO(input0, multihead); + IR_NODE_LINK_TO(input_cos, multihead); + IR_NODE_LINK_TO(input_sin, multihead); + IR_NODE_LINK_TO(mul0_w, multihead); + IR_NODE_LINK_TO(eltadd0_b, multihead); + IR_NODE_LINK_TO(eltadd_qk_b, multihead); + + IR_NODE_LINK_TO(multihead, reshape2_qkv_out); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input_cos, input_cos, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input_sin, input_sin, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltmul_cos_q, eltmul_cos_q, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltmul_cos_q_out, eltmul_cos_q_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltmul_sin_q, eltmul_sin_q, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltmul_sin_q_out, 
eltmul_sin_q_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(split_q, split_q, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(split_q_out, split_q_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(concat_q, concat_q, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(concat_q_out, concat_q_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_q, eltadd_q, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_q_out, eltadd_q_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltmul_cos_k, eltmul_cos_k, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltmul_cos_k_out, eltmul_cos_k_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltmul_sin_k, eltmul_sin_k, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltmul_sin_k_out, eltmul_sin_k_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(split_k, split_k, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(split_k_out, split_k_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(concat_k, concat_k, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(concat_k_out, concat_k_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_k, eltadd_k, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_k_out, eltadd_k_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, 
multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); + + // If weights or biases in qkv's fc are shared by multiple multihead_matmul + // patterns, we do not support this kind of fusion, this pass will not take + // effect. + bool is_fc_params_shared = + mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 || + mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 || + eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1; + if (is_fc_params_shared) { + return; + } + fuse_creater(input0, + input_cos, + input_sin, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul0_w, + mul1_w, + mul2_w, + eltadd0_b, + eltadd1_b, + eltadd2_b, + eltadd_qk_b, + reshape2_0, + reshape2_qkv_out, + scale, + scale_out, + matmul_qk); + + std::unordered_set marked_nodes({eltadd0, + eltadd1, + eltadd2, + eltadd1_b, + eltadd2_b, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + eltmul_cos_q, + eltmul_cos_q_out, + eltmul_sin_q, + eltmul_sin_q_out, + eltmul_cos_k, + eltmul_cos_k_out, + eltmul_sin_k, + eltmul_sin_k_out, + split_q, + split_q_out, + concat_q, + concat_q_out, + split_k, + split_k_out, + concat_k, + concat_k_out, + eltadd_q, + eltadd_q_out, + eltadd_k, + eltadd_k_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul1_w, + mul2_w, + reshape2_qkv, + scale}); + // Remove unneeded nodes. 
+ GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void MultiHeadMatmulRoformerFusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal( + "During the multiheadMatmul pass, The scope should not be null.")); + + int fusion_count = BuildFusion(graph, name_scope_, scope); + if (fusion_count > 0) { + graph->Set(kMultiheadMatmulPass, new bool(true)); + } + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(multihead_matmul_roformer_fuse_pass, + paddle::framework::ir::MultiHeadMatmulRoformerFusePass); + +REGISTER_PASS_CAPABILITY(multihead_matmul_roformer_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.h b/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..4d081e7c3ac780a93d1e82ea722227458c9d0808 --- /dev/null +++ b/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.h @@ -0,0 +1,176 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +/* + * \brief Fuse the subgraph representing multihead attention part of roformer + * into multihead_matmul_roformer op. 
+ *
+ * \note The following graph represents the fused computation:
+ *
+ * x - input data
+ * cos - input data of cos mat
+ * sin - input data of sin mat
+ * ele_add - elementwise_add
+ * ele_mul - elementwise_mul
+ *
+ *                x
+ *              / | \
+ *             /  |  \
+ *            /   |   \
+ *           |    |    |
+ *           |    |    |
+ *          mul  mul  mul
+ *           |    |    |
+ *      ele_add ele_add ele_add
+ *           |    |    |
+ *     reshape2 reshape2 reshape2
+ *           |    |    |
+ *   transpose2 transpose2 transpose2
+ *           |   /  \        /  \
+ *           |  |    |      |    |
+ *           |  | cos split | sin split
+ *           |  | /    |    | /    |
+ *           | ele_mul concat ele_mul concat
+ *           |    |      |      |      |
+ *           |     \    /        \    /
+ *           |    ele_add       ele_add
+ *           |       |             |
+ *           |       |           scale
+ *           |       |             |
+ *           |        \           /
+ *           |           matmul
+ *           |             |
+ *           |          ele_add
+ *            \            |
+ *             \        softmax
+ *              \          |
+ *               \        /
+ *                matmul
+ *
+ */
+
+struct MultiHeadMatmulRoformerPattern : public PatternBase {
+  MultiHeadMatmulRoformerPattern(PDPattern* pattern,
+                                 const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "multihead_matmul_roformer") {}
+
+  PDNode* operator()();
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(input0);
+  PATTERN_DECL_NODE(input_cos);
+  PATTERN_DECL_NODE(input_sin);
+  PATTERN_DECL_NODE(mul0);
+  PATTERN_DECL_NODE(mul1);
+  PATTERN_DECL_NODE(mul2);
+  PATTERN_DECL_NODE(mul0_w);
+  PATTERN_DECL_NODE(mul1_w);
+  PATTERN_DECL_NODE(mul2_w);
+  PATTERN_DECL_NODE(mul0_out);
+  PATTERN_DECL_NODE(mul1_out);
+  PATTERN_DECL_NODE(mul2_out);
+  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_out);
+  PATTERN_DECL_NODE(eltadd1_out);
+  PATTERN_DECL_NODE(eltadd2_out);
+  PATTERN_DECL_NODE(reshape2_0);
+  PATTERN_DECL_NODE(reshape2_1);
+  PATTERN_DECL_NODE(reshape2_2);
+  PATTERN_DECL_NODE(reshape2_qkv);
+  PATTERN_DECL_NODE(reshape2_0_out);
+  PATTERN_DECL_NODE(reshape2_1_out);
+  PATTERN_DECL_NODE(reshape2_2_out);
+  PATTERN_DECL_NODE(reshape2_qkv_out);
+  PATTERN_DECL_NODE(scale);
+  PATTERN_DECL_NODE(scale_out);
+  PATTERN_DECL_NODE(transpose2_0);
+  PATTERN_DECL_NODE(transpose2_1);
+  PATTERN_DECL_NODE(transpose2_2);
+  PATTERN_DECL_NODE(transpose2_qkv);
+  PATTERN_DECL_NODE(transpose2_0_out);
+  PATTERN_DECL_NODE(transpose2_1_out);
+  PATTERN_DECL_NODE(transpose2_2_out);
+  PATTERN_DECL_NODE(transpose2_qkv_out);
+
+  PATTERN_DECL_NODE(eltmul_cos_q);
+  PATTERN_DECL_NODE(eltmul_cos_q_out);
+  PATTERN_DECL_NODE(eltmul_sin_q);
+  PATTERN_DECL_NODE(eltmul_sin_q_out);
+  PATTERN_DECL_NODE(eltmul_cos_k);
+  PATTERN_DECL_NODE(eltmul_cos_k_out);
+  PATTERN_DECL_NODE(eltmul_sin_k);
+  PATTERN_DECL_NODE(eltmul_sin_k_out);
+
+  PATTERN_DECL_NODE(split_q);
+  PATTERN_DECL_NODE(split_q_out);
+  PATTERN_DECL_NODE(concat_q);
+  PATTERN_DECL_NODE(concat_q_out);
+  PATTERN_DECL_NODE(split_k);
+  PATTERN_DECL_NODE(split_k_out);
+  PATTERN_DECL_NODE(concat_k);
+  PATTERN_DECL_NODE(concat_k_out);
+
+  PATTERN_DECL_NODE(eltadd_q);
+  PATTERN_DECL_NODE(eltadd_q_out);
+  PATTERN_DECL_NODE(eltadd_k);
+  PATTERN_DECL_NODE(eltadd_k_out);
+
+  PATTERN_DECL_NODE(matmul_qk);
+  PATTERN_DECL_NODE(matmul_qk_out);
+  PATTERN_DECL_NODE(eltadd_qk);
+  PATTERN_DECL_NODE(eltadd_qk_b);
+  PATTERN_DECL_NODE(eltadd_qk_out);
+  PATTERN_DECL_NODE(softmax_qk);
+  PATTERN_DECL_NODE(softmax_qk_out);
+
+  PATTERN_DECL_NODE(matmul_qkv);
+  PATTERN_DECL_NODE(matmul_qkv_out);
+};
+
+}  // namespace patterns
+
+class MultiHeadMatmulRoformerFusePass : public FusePassBase {
+ public:
+  MultiHeadMatmulRoformerFusePass();
+
protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"multihead_matmul_roformer_fuse"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index d9adc9426b44729ec23c324966d3e3ec2621a79a..dcfa0951d3d27c1360f7f9116bf37e60e3c5c142 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2265,6 +2265,7 @@ USE_TRT_CONVERTER(instance_norm); USE_TRT_CONVERTER(layer_norm); USE_TRT_CONVERTER(gelu); USE_TRT_CONVERTER(multihead_matmul); +USE_TRT_CONVERTER(multihead_matmul_roformer); USE_TRT_CONVERTER(skip_layernorm); USE_TRT_CONVERTER(slice); USE_TRT_CONVERTER(scale); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index aa699b445839722edd2dc3c8069dbe87cd3ad297..e6dc79b509dab96e45c07f5aa6a6a34777ebd0fb 100755 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -106,6 +106,7 @@ const std::vector kTRTSubgraphPasses({ "delete_c_identity_op_pass", // "trt_multihead_matmul_fuse_pass_v2", // "trt_multihead_matmul_fuse_pass_v3", // + "multihead_matmul_roformer_fuse_pass", // "constant_folding_pass", // "vit_attention_fuse_pass", // "trt_skip_layernorm_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 849bccc3c291d8ae081a7610cdd096d0a92a8057..3a9a7527db69c538c3f9ef872be332820e783b85 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -23,6 +23,7 @@ list( gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc + multihead_matmul_roformer_op.cc shuffle_channel_op.cc swish_op.cc silu_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..879c8fa9d6a5de6ab8cb4155225f0b5944e6592c --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc @@ -0,0 +1,206 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class MultiheadMatMulRoformerOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(3) << "convert a multihead_matmul_roformer op to a corresponding "
+               "tensorrt network structure";
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("Input").front());
+    auto* input_cos = engine_->GetITensor(op_desc.Input("Input_cos").front());
+    auto* input_sin = engine_->GetITensor(op_desc.Input("Input_sin").front());
+    // fc weights and fc bias
+    auto weight_name = op_desc.Input("W").front();
+    auto bias_name = op_desc.Input("Bias").front();
+
+    auto* weight_v = scope.FindVar(weight_name);
+    auto* weight_t = weight_v->GetMutable<phi::DenseTensor>();
+
+    auto* bias_v = scope.FindVar(bias_name);
+    auto* bias_t = bias_v->GetMutable<phi::DenseTensor>();
+
+    float* weight_data = nullptr;
+    float in_scale = 0.;
+
+    if (op_desc.HasAttr("Input_scale")) {
+      in_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("Input_scale"));
+      engine_->SetTensorDynamicRange(input, in_scale);
+    }
+    weight_data = const_cast<float*>(static_cast<const float*>(
+        engine_->GetFp32TrtWeight(weight_name, *weight_t).get().values));
+
+    float* bias_data = const_cast<float*>(static_cast<const float*>(
+        engine_->GetFp32TrtWeight(bias_name, *bias_t).get().values));
+    std::vector<float> weight_data_tmp;
+    weight_data_tmp.resize(weight_t->numel());
+    memcpy(
+        weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float));
+
+    // (hidden_in, 3, hidden_out)
+    auto weight_dims = weight_t->dims();
+
+    int hidden_in = weight_dims[0];   // channels_in
+    int three = weight_dims[1];       // 3: q, k, v
+    int hidden_out = weight_dims[2];  // channels_out
+    int m = hidden_in;
+    int n = three * hidden_out;
+    auto transpose_weight = [](const float* src, float* dst, int m, int n) {
+      for (int i = 0; i < m; i++) {
+        for (int j = 0; j < n; j++) {
+          dst[j * m + i] = src[i * n + j];
+        }
+      }
+    };
+    transpose_weight(weight_data_tmp.data(), weight_data, m, n);
+
+    int head_number = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number"));
+
+    nvinfer1::ILayer* layer = nullptr;
+    auto output_name = op_desc.Output("Out")[0];
+    bool flag_varseqlen = engine_->use_varseqlen() &&
+                          engine_->tensorrt_transformer_posid() != "" &&
+                          engine_->tensorrt_transformer_maskid() != "";
+
+    if (engine_->with_dynamic_shape()) {
+      if (flag_varseqlen) {
+        PADDLE_THROW(platform::errors::Fatal(
+            "roformer does not support varseqlen yet"));
+      } else {
+        PADDLE_ENFORCE_EQ(
+            input->getDimensions().nbDims,
+            3,
+            platform::errors::InvalidArgument(
+                "The Input dim of multihead_matmul_roformer should be 3, "
+                "but it's (%d) now.",
+                input->getDimensions().nbDims));
+        // transpose weight_data from m * n to n * m
+        auto* input_bias_qk =
+            engine_->GetITensor(op_desc.Input("BiasQK").front());
+
+        TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
+                                      static_cast<void*>(weight_data),
+                                      static_cast<size_t>(weight_t->numel())};
+        weight.dims.assign({n, m});
+
+        TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
+                                    static_cast<void*>(bias_data),
+                                    static_cast<size_t>(bias_t->numel())};
+
+        // add shuffle before fc
+        nvinfer1::Dims reshape_before_fc_dim;
+        reshape_before_fc_dim.nbDims = 5;
+        reshape_before_fc_dim.d[0] = 0;
+        reshape_before_fc_dim.d[1] = 0;
+        reshape_before_fc_dim.d[2] = 0;
reshape_before_fc_dim.d[3] = 1; + reshape_before_fc_dim.d[4] = 1; + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + if (op_desc.HasAttr("Input_scale")) { + engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), + in_scale); + } + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setName( + ("shuffle_before_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + // add layer fc + nvinfer1::ILayer* fc_layer = nullptr; + if (op_desc.HasAttr("Input_scale")) { + nvinfer1::DimsHW nv_ksize(1, 1); + fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, + Convolution, + *reshape_before_fc_layer->getOutput(0), + n, + nv_ksize, + weight.get(), + bias.get()); + } else { + fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *reshape_before_fc_layer->getOutput(0), + n, + weight.get(), + bias.get()); + } + + if (op_desc.HasAttr("fc_out_threshold")) { + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("fc_out_threshold"), + true, + platform::errors::InvalidArgument( + "must have out threshold in multihead layers in int8 mode")); + float out_scale = + PADDLE_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + } + fc_layer->setName( + ("multihead_matmul_fc(Output: " + output_name + ")").c_str()); + + // no need to add shuffle after fc, just change it in + // QkvToContextPluginDynamic + + // add qkv to context + int head_size = hidden_out / head_number; + float scale = PADDLE_GET_CONST(float, op_desc.GetAttr("alpha")); + + std::vector plugin_inputs; + plugin_inputs.push_back(fc_layer->getOutput(0)); + plugin_inputs.push_back(input_cos); + plugin_inputs.push_back(input_sin); + plugin_inputs.push_back(input_bias_qk); + bool with_fp16 = + engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; + } + plugin::DynamicPluginTensorRT* plugin = + new plugin::MultiheadMatmulRoformerPlugin( + hidden_in, head_number, head_size, scale, with_fp16); + layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 4, plugin); + } + } else { + PADDLE_THROW(platform::errors::Fatal( + "You are running the Ernie(Bert) model in static shape mode, which " + "is not supported for the time being.\n" + "You can use the config.SetTRTDynamicShapeInfo(...) 
interface to set " + "the shape information to run the dynamic shape mode.")); + } + RreplenishLayerAndOutput( + layer, "multihead_matmul_roformer", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(multihead_matmul_roformer, + MultiheadMatMulRoformerOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 75bc402deb4cde7012df59ceaabbb81d69919503..32297317df806b686e74aed47cbfa760d8fdb7a0 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1723,6 +1723,58 @@ struct SimpleOpTypeSetTeller : public Teller { } } + if (op_type == "multihead_matmul_roformer") { + if (!with_dynamic_shape) { + VLOG(3) << "the multihead_matmul_roformer does not support static " + "shape yet"; + return false; + } + + if (desc.HasAttr("enable_int8") && !desc.HasAttr("Input_scale")) { + VLOG(3) << "Multihead layers must have input scale in int8 mode."; + return false; + } + + auto* block = desc.Block(); + if (block == nullptr) { + VLOG(3) << "The block desc is nullptr, we can't continue to analyze. " + "Developers need to check whether block_desc is passed in " + "the pass."; + return false; + } + auto* input_desc = block->FindVar(desc.Input("Input").front()); + const auto input_shape = input_desc->GetShape(); + const auto head_number = + PADDLE_GET_CONST(int, desc.GetAttr("head_number")); + auto inputs = desc.Inputs(); + bool has_bias_qk = (inputs.find("BiasQK") == inputs.end()) ? false : true; + if (has_bias_qk) { + auto* biasqk_desc = block->FindVar(desc.Input("BiasQK").front()); + const auto biasqk_shape = biasqk_desc->GetShape(); + // The BiasQK's shape requires to be + // [batch, 1, 1, length] or [batch, head, length, length]. 
+ bool has_same_shape = head_number == biasqk_shape[1] && + input_shape[1] == biasqk_shape[2] && + input_shape[1] == biasqk_shape[3]; + bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 && + input_shape[1] == biasqk_shape[3]; + if (!(has_same_shape || is_broadcastable)) { + VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0] + << ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0] + << ", " << head_number << ", " << input_shape[1] << ", " + << input_shape[1] << "] but [" << biasqk_shape[0] << ", " + << biasqk_shape[1] << ", " << biasqk_shape[2] << ", " + << biasqk_shape[3] << "]."; + return false; + } + } else { +#if !IS_TRT_VERSION_GE(8000) + VLOG(3) << "The version of TRT must be greater than 8000"; + return false; +#endif + } + } + if (op_type == "fc") { auto* block = desc.Block(); if (block == nullptr) { @@ -2271,6 +2323,7 @@ struct SimpleOpTypeSetTeller : public Teller { "clip", "fused_embedding_eltwise_layernorm", "multihead_matmul", + "multihead_matmul_roformer", "skip_layernorm", "slice", "strided_slice", @@ -2394,6 +2447,7 @@ struct SimpleOpTypeSetTeller : public Teller { "clip", "fused_embedding_eltwise_layernorm", "multihead_matmul", + "multihead_matmul_roformer", "skip_layernorm", "slice", "strided_slice", diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index a544d18a57d17bc214931aaec97f96724ee26179..a72880780d81e7cf136a53044fc2079cc80553ba 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -25,6 +25,7 @@ list( pool3d_op_plugin.cu deformable_conv_op_plugin.cu matmul_op_int8_plugin.cu + multihead_matmul_roformer_plugin.cu transformer_input_convert_plugin.cu remove_padding_plugin.cu recover_padding_plugin.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/common/common.cuh b/paddle/fluid/inference/tensorrt/plugin/common/common.cuh index 10bf23fbc531bd963d8a6894af23c74a0586c00f..480d3a5733c2a96994d6838f6ca49f0223ba6f5f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/common/common.cuh +++ b/paddle/fluid/inference/tensorrt/plugin/common/common.cuh @@ -18,6 +18,7 @@ #include #include "cublas_v2.h" +#include "paddle/fluid/platform/device_context.h" using kv_float = cub::KeyValuePair; using kv_half = cub::KeyValuePair; @@ -144,3 +145,154 @@ __device__ inline void layerNorm(const kvp& threadData, output[idx] = g * (val - mu) * rsigma + b; } } + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +// Helper Functions for multihead related plugins +template +__global__ void transpose(T *src, + T *dst, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head) { + int batch_id = blockIdx.x / (head_num * seq_len); + int seq_id = blockIdx.x % seq_len; + int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; + dst[batch_id * (head_num * seq_len * size_per_head) + + seq_id * head_num * size_per_head + head_id * size_per_head + + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x]; +} + +template +__global__ void TransposeQkvKernel(const int H, const T *input, T *output) { + // Input: BxSx3xNxH + // Bias: 3xSxB + // Output: 3xBxNxSxH + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; + + const int N = blockDim.y; + const int S = gridDim.x; + const int B = gridDim.y; + + const int NH = N * H; + const int NHS = NH * S; + const int in_offset = n * H + m * NH + s * 3 * NH + b 
* NHS * 3; + const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B; + + const int i = threadIdx.x; + output[out_offset + i] = input[in_offset + i]; +} + +inline void TransposeQKV(const int batch, + const int seq_len, + const int head_size, + const int head_num, + const float *input, + float *output, + cudaStream_t stream) { + int scratch_size = batch * head_num * seq_len * seq_len; + const dim3 grid(seq_len, batch, 3); + if (head_size % 4 == 0 && scratch_size % 4 == 0) { + const int h = head_size / 4; + const float4 *input4 = reinterpret_cast(input); + float4 *output4 = reinterpret_cast(output); + const dim3 block(h, head_num, 1); + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, + 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024 * 4)); + TransposeQkvKernel<<>>(h, input4, output4); + } else if (head_size % 2 == 0 && scratch_size % 2 == 0) { + const int h = head_size / 2; + const float2 *input2 = reinterpret_cast(input); + float2 *output2 = reinterpret_cast(output); + const dim3 block(h, head_num, 1); + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, + 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024 * 2)); + TransposeQkvKernel<<>>(h, input2, output2); + } else { + const dim3 block(head_size, head_num, 1); + // limit head_size * head_num to max block size(1024). + PADDLE_ENFORCE_LE(head_size * head_num, + 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024)); + TransposeQkvKernel + <<>>(head_size, input, output); + } +} + +inline void TransposeQKV(const int batch, + const int seq_len, + const int head_size, + const int head_num, + const half *input, + half *output, + cudaStream_t stream) { + int scratch_size = batch * head_num * seq_len * seq_len; + const dim3 grid(seq_len, batch, 3); + if (head_size % 8 == 0 && scratch_size % 8 == 0) { + int h = head_size / 8; + const int4 *input4 = reinterpret_cast(input); + int4 *output4 = reinterpret_cast(output); + dim3 block(h, head_num, 1); + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, + 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024 * 8)); + TransposeQkvKernel<<>>(h, input4, output4); + } else if (head_size % 2 == 0 && scratch_size % 2 == 0) { + const int h = head_size / 2; + const half2 *input2 = reinterpret_cast(input); + half2 *output2 = reinterpret_cast(output); + const dim3 block(h, head_num, 1); + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, + 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024 * 2)); + TransposeQkvKernel<<>>(h, input2, output2); + } else { + const dim3 block(head_size, head_num, 1); + // limit head_size * head_num to max block size(1024). 
+ PADDLE_ENFORCE_LE(head_size * head_num, + 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024)); + TransposeQkvKernel + <<>>(head_size, input, output); + } +} +} +} +} +} diff --git a/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..50f263ad61735e925c94ad8fb104de35258763e2 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu @@ -0,0 +1,381 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.h" +#include +#include +#include // NOLINT +#include +#include "glog/logging.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h" +#include "paddle/fluid/operators/math/bert_encoder_functor.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +// Dynamic Plugin below. 
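Reading ahead, the enqueue implementation below stitches the common.cuh helpers together. A summary of its fp32 path (a trace of the code that follows, not additional API):

```
1. broadcast (only if BiasQK is (B, 1, 1, S)): expand bias to (B, N, S, S)
2. TransposeQKV: fused fc output (B, S, 3, N, H) -> tptr (3, B, N, S, H)
3. cudaMemcpy: tptr -> tmp_roformer_ptr (device to device)
4. RotrayKernel on the Q and K thirds of tptr:
     out = cos * x + sin * rotate_half(x)     (the V third is untouched)
5. MultiHeadGPUComputeFunctor: softmax(scale * Q x K^T + bias) x V
6. transpose: (B, N, S, H) -> output (B, S, N*H)
```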
+#if IS_TRT_VERSION_GE(6000)
+
+int MultiheadMatmulRoformerPlugin::initialize() TRT_NOEXCEPT { return 0; }
+
+nvinfer1::DimsExprs MultiheadMatmulRoformerPlugin::getOutputDimensions(
+    int output_index,
+    const nvinfer1::DimsExprs *inputs,
+    int nb_inputs,
+    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
+  // input[0], (B, S, 3 * N * H, 1, 1)
+  // input[1] / input[2], cos / sin
+  // input[3], (B, head_num, seq_len, seq_len)
+  // output, (B, seq_len, hidden)
+  PADDLE_ENFORCE_EQ(output_index,
+                    0,
+                    platform::errors::InvalidArgument(
+                        "There is only one output of the "
+                        "MultiheadMatmulRoformer plugin, "
+                        "so the index should be zero, "
+                        "but it's (%d)",
+                        output_index));
+  PADDLE_ENFORCE_EQ(
+      nb_inputs,
+      4,
+      platform::errors::InvalidArgument(
+          "The input number of the MultiheadMatmulRoformer plugin should "
+          "be 4, but we found it has (%d) inputs",
+          nb_inputs));
+  nvinfer1::DimsExprs ret;
+  ret.nbDims = 3;
+  ret.d[0] = inputs[0].d[0];
+  ret.d[1] = inputs[0].d[1];
+  ret.d[2] = expr_builder.constant(head_size_ * head_number_);
+  return ret;
+}
+
+bool MultiheadMatmulRoformerPlugin::supportsFormatCombination(
+    int pos,
+    const nvinfer1::PluginTensorDesc *in_out,
+    int nb_inputs,
+    int nb_outputs) TRT_NOEXCEPT {
+  PADDLE_ENFORCE_NOT_NULL(
+      in_out,
+      platform::errors::InvalidArgument(
+          "The input of the MultiheadMatmulRoformer plugin should not be "
+          "nullptr."));
+
+  PADDLE_ENFORCE_LT(
+      pos,
+      nb_inputs + nb_outputs,
+      platform::errors::InvalidArgument("The pos(%d) should be less than the "
+                                        "num(%d) of the input and the output.",
+                                        pos,
+                                        nb_inputs + nb_outputs));
+
+  const nvinfer1::PluginTensorDesc &in = in_out[pos];
+  if (pos == 0) {
+    if (with_fp16_) {
+#ifdef TRT_PLUGIN_FP16_AVALIABLE
+      return (in.type == nvinfer1::DataType::kFLOAT ||
+              in.type == nvinfer1::DataType::kHALF) &&
+             (in.format == nvinfer1::TensorFormat::kLINEAR);
+#else
+      return (in.type == nvinfer1::DataType::kFLOAT) &&
+             (in.format == nvinfer1::TensorFormat::kLINEAR);
+#endif
+    } else {
+      return (in.type == nvinfer1::DataType::kFLOAT) &&
+             (in.format == nvinfer1::TensorFormat::kLINEAR);
+    }
+  }
+  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
+
+  if (pos == 1) {
+    return in.type == prev.type && in.format == prev.format;
+  }
+
+  // output
+  return in.type == prev.type && in.format == prev.format;
+}
+
+nvinfer1::DataType MultiheadMatmulRoformerPlugin::getOutputDataType(
+    int index,
+    const nvinfer1::DataType *input_types,
+    int nb_inputs) const TRT_NOEXCEPT {
+  PADDLE_ENFORCE_EQ(
+      index,
+      0,
+      platform::errors::InvalidArgument(
+          "The MultiheadMatmulRoformer plugin only has one output, so the "
+          "index value should be 0, but get %d.",
+          index));
+  return input_types[0];
+}
+
+template <typename T>
+__global__ void apply_scale(T *data, T scale, int n) {
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < n) {
+    data[tid] = data[tid] * scale;
+  }
+#endif
+}
+
+template <typename T>
+__global__ void RotrayKernel(const T *inputact,
+                             const T *input1,
+                             const T *input2,
+                             T *output,
+                             const int nElement,
+                             const int lastdim) {
+  const int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index >= nElement) return;
+  T left_elemul_out = input1[index] * inputact[index];
+  int col = index % lastdim;
+  int half_lastdim = lastdim / 2;
+  const int right_index = index - col + (col + half_lastdim) % lastdim;
+  output[index] = left_elemul_out + input2[index] * inputact[right_index];
+}
+
+inline int round_up(int seq_len, int multiple = 32) {
+  PADDLE_ENFORCE_GT(
+      multiple,
+      0,
+      platform::errors::InvalidArgument(
+          "multiple should be a positive number, but it's (%d)",
+          multiple));
+  return ((seq_len + multiple - 1) / multiple) * multiple;
+}
+
+template <typename T>
+__global__ void broadcast(const T *src,
+                          T *dst,
+                          const int seq_len,
+                          const int head_num) {
+  int batch_id = blockIdx.x / (head_num * seq_len);
+  int dst_offset = blockIdx.x * seq_len;
+  if (threadIdx.x < seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
+  }
+}
+
+int MultiheadMatmulRoformerPlugin::enqueue(
+    const nvinfer1::PluginTensorDesc *input_desc,
+    const nvinfer1::PluginTensorDesc *output_desc,
+    const void *const *inputs,
+    void *const *outputs,
+    void *workspace,
+    cudaStream_t stream) TRT_NOEXCEPT {
+  auto input_dims = input_desc[0].dims;
+  int input_num = ProductDim(input_dims);
+  // input[0], (B, S, 3 * N * H, 1, 1)
+  int batch = input_dims.d[0];
+  int seq_len = input_dims.d[1];
+  phi::DenseTensor multihead_temp_tensor;
+  // scratch space for the attention mask
+  int scratch_size = batch * head_number_ * seq_len * seq_len * 1;
+
+  int device_id;
+  cudaGetDevice(&device_id);
+  multihead_temp_tensor.Resize({scratch_size + input_num});
+  // extra scratch tensor holding the pre-rotary copy of q and k
+  phi::DenseTensor temp_roformer_tensor;
+  temp_roformer_tensor.Resize({input_num});
+
+  auto input_type = input_desc[0].type;
+  if (input_type == nvinfer1::DataType::kFLOAT) {
+    VLOG(1) << "TRT Plugin DataType selected. RoformerQkvToContext-->fp32";
+    auto *multihead_temp_data = multihead_temp_tensor.mutable_data<float>(
+        platform::CUDAPlace(device_id));
+    auto *temp_roformer_data =
+        temp_roformer_tensor.mutable_data<float>(  // NOLINT
+            platform::CUDAPlace(device_id));
+    auto *tmp_roformer_ptr = reinterpret_cast<float *>(temp_roformer_data);
+    auto *qkptr = multihead_temp_data;
+    auto *tptr = multihead_temp_data + scratch_size;
+
+    const float *input0_data = static_cast<const float *>(inputs[0]);
+    // fit to [batch, head_num, length, length] + [batch, 1, 1, length]
+    phi::DenseTensor temp_qk_bias_tensor;
+    float *qk_bias =
+        const_cast<float *>(static_cast<const float *>(inputs[3]));
+    if (ProductDim(input_desc[3].dims) == (batch * seq_len)) {
+      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
+      auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data<float>(
+          platform::CUDAPlace(device_id));
+      int grid = batch * head_number_ * seq_len;
+      int block = round_up(seq_len);
+      broadcast<<<grid, block, 0, stream>>>(
+          static_cast<const float *>(inputs[3]),
+          temp_qk_bias,
+          seq_len,
+          head_number_);
+      qk_bias = temp_qk_bias;
+    }
+    const float *input3_data = static_cast<const float *>(qk_bias);
+    // BxSx3xNxH => tptr: 3xBxNxSxH.
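+    // The q and k slices of tptr are rotated in place: the transposed qkv is
+    // first copied to tmp_roformer_ptr, then RotrayKernel computes
+    // q' = cos .* q + sin .* swap_halves(q) (and likewise for k), reading
+    // from the copy and writing back into tptr. Any sign handling of the
+    // classic RoPE formulation is assumed to be baked into the sin input by
+    // the graph that feeds this plugin.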
+    TransposeQKV(
+        batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
+    cudaMemcpy(tmp_roformer_ptr,  // dst
+               tptr,              // src
+               input_num * sizeof(float),
+               cudaMemcpyDeviceToDevice);
+    int n_q = seq_len * head_number_ * head_size_ * batch;
+    constexpr int threads = 128;
+    int blocks = (n_q + threads - 1) / threads;
+    const float *input_cos_data = static_cast<const float *>(inputs[1]);
+    const float *input_sin_data = static_cast<const float *>(inputs[2]);
+    RotrayKernel<<<blocks, threads, 0, stream>>>(tmp_roformer_ptr,
+                                                 input_cos_data,
+                                                 input_sin_data,
+                                                 tptr,
+                                                 n_q,
+                                                 head_size_);  // q
+    RotrayKernel<<<blocks, threads, 0, stream>>>(tmp_roformer_ptr + n_q,
+                                                 input_cos_data,
+                                                 input_sin_data,
+                                                 tptr + n_q,
+                                                 n_q,
+                                                 head_size_);  // k
+
+    auto *device_ctx = static_cast<phi::GPUContext *>(
+        platform::DeviceContextPool::Instance().Get(
+            platform::CUDAPlace(device_id)));
+
+    const phi::GPUContext &dev_ctx = *device_ctx;
+    operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
+    multihead_compute_func(dev_ctx,
+                           batch,
+                           seq_len,
+                           head_number_,
+                           head_size_,
+                           qkptr,
+                           input3_data,
+                           false,
+                           tptr,
+                           scale_,
+                           static_cast<float>(0.0));
+
+    int grid = batch * head_number_ * seq_len;
+    int block = head_size_;
+    float *output = static_cast<float *>(outputs[0]);
+    transpose<float><<<grid, block, 0, stream>>>(
+        tptr, output, batch, seq_len, head_number_, head_size_);
+
+  } else if (input_type == nvinfer1::DataType::kHALF) {
+#ifdef TRT_PLUGIN_FP16_AVALIABLE
+    VLOG(1) << "TRT Plugin DataType selected. RoformerQkvToContext-->fp16";
+    auto *multihead_temp_data =
+        multihead_temp_tensor.mutable_data<int16_t>(  // NOLINT
+            platform::CUDAPlace(device_id));
+
+    auto *temp_roformer_data =
+        temp_roformer_tensor.mutable_data<int16_t>(  // NOLINT
+            platform::CUDAPlace(device_id));
+    half *tmp_roformer_ptr = reinterpret_cast<half *>(temp_roformer_data);
+    half *qkptr = reinterpret_cast<half *>(multihead_temp_data);
+    half *tptr = qkptr + scratch_size;
+
+    const half *input0_data = static_cast<const half *>(inputs[0]);
+    // fit to [batch, head_num, length, length] + [batch, 1, 1, length]
+    phi::DenseTensor temp_qk_bias_tensor;
+    half *qk_bias = const_cast<half *>(static_cast<const half *>(inputs[3]));
+    if (ProductDim(input_desc[3].dims) == (batch * seq_len)) {
+      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
+      auto *temp_qk_bias =
+          reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
+              platform::CUDAPlace(device_id)));
+      int grid = batch * head_number_ * seq_len;
+      int block = round_up(seq_len);
+      broadcast<<<grid, block, 0, stream>>>(
+          static_cast<const half *>(inputs[3]),
+          temp_qk_bias,
+          seq_len,
+          head_number_);
+      qk_bias = temp_qk_bias;
+    }
+    const half *input3_data = static_cast<const half *>(qk_bias);
+    // BxSx3xNxH => tptr: 3xBxNxSxH.
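+    // The fp16 path mirrors the fp32 path above, except that the softmax
+    // scale is applied up front by apply_scale (see below) and the compute
+    // functor is then invoked with a scale of 1.0.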
+    TransposeQKV(
+        batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
+    cudaMemcpy(tmp_roformer_ptr,
+               tptr,
+               input_num * sizeof(half),
+               cudaMemcpyDeviceToDevice);
+
+    auto *device_ctx = static_cast<phi::GPUContext *>(
+        platform::DeviceContextPool::Instance().Get(
+            platform::CUDAPlace(device_id)));
+
+    int n_q = seq_len * head_number_ * head_size_ * batch;
+    constexpr int threads = 128;
+    int blocks = (n_q + threads - 1) / threads;
+
+    const half *input_cos_data = static_cast<const half *>(inputs[1]);
+    const half *input_sin_data = static_cast<const half *>(inputs[2]);
+    RotrayKernel<<<blocks, threads, 0, stream>>>(tmp_roformer_ptr,
+                                                 input_cos_data,
+                                                 input_sin_data,
+                                                 tptr,
+                                                 n_q,
+                                                 head_size_);  // q
+    RotrayKernel<<<blocks, threads, 0, stream>>>(tmp_roformer_ptr + n_q,
+                                                 input_cos_data,
+                                                 input_sin_data,
+                                                 tptr + n_q,
+                                                 n_q,
+                                                 head_size_);  // k
+
+    apply_scale<<<blocks, threads, 0, stream>>>(
+        tptr, static_cast<half>(scale_), n_q);
+
+    const phi::GPUContext &dev_ctx = *device_ctx;
+    operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
+    multihead_compute_func(dev_ctx,
+                           batch,
+                           seq_len,
+                           head_number_,
+                           head_size_,
+                           qkptr,
+                           input3_data,
+                           false,
+                           tptr,
+                           half(1.),
+                           half(0.0));
+
+    int grid = batch * head_number_ * seq_len;
+    int block = head_size_;
+    half *output = static_cast<half *>(outputs[0]);
+    transpose<half><<<grid, block, 0, stream>>>(
+        tptr, output, batch, seq_len, head_number_, head_size_);
+#else
+    PADDLE_THROW(platform::errors::Fatal(
+        "The Ernie(Bert) TensorRT Plugin should be "
+        "compiled with CUDA version >= 10.0 when running with fp16. "
+        "Please recompile it or try to use fp32 by setting "
+        "config.SetTRTDynamicShapeInfo(min_input_shape, "
+        "max_input_shape, opt_input_shape, true)"));
+#endif
+  } else {
+    PADDLE_THROW(platform::errors::Fatal(
+        "The QKV TRT Plugin's input type should be float or half."));
+  }
+  return cudaGetLastError() != cudaSuccess;
+}
+#endif
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.h b/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..3f2a106fcc969f600b9d862fe15e459316edbfd2
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.h
@@ -0,0 +1,163 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
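+
+// TensorRT dynamic plugin that runs fused RoFormer-style multi-head
+// attention: it consumes the packed qkv projection, the rotary cos/sin
+// tensors, and a qk bias, and emits the attention context as its single
+// output (see the matching .cu file for the kernels).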
+
+#pragma once
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+#if IS_TRT_VERSION_GE(6000)
+class MultiheadMatmulRoformerPlugin : public DynamicPluginTensorRT {
+ public:
+  explicit MultiheadMatmulRoformerPlugin(
+      int hidden, int head_number, int head_size, float scale, bool with_fp16)
+      : hidden_(hidden),
+        head_number_(head_number),
+        head_size_(head_size),
+        scale_(scale) {
+    with_fp16_ = with_fp16;
+  }
+
+  MultiheadMatmulRoformerPlugin(void const* serial_data,
+                                size_t serial_length) {
+    DeserializeValue(&serial_data, &serial_length, &hidden_);
+    DeserializeValue(&serial_data, &serial_length, &head_number_);
+    DeserializeValue(&serial_data, &serial_length, &head_size_);
+    DeserializeValue(&serial_data, &serial_length, &scale_);
+    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
+  }
+  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
+    return new MultiheadMatmulRoformerPlugin(
+        hidden_, head_number_, head_size_, scale_, with_fp16_);
+  }
+
+  const char* getPluginType() const TRT_NOEXCEPT override {
+    return "multihead_matmul_roformer_plugin";
+  }
+  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
+  int initialize() TRT_NOEXCEPT override;
+
+  size_t getSerializationSize() const TRT_NOEXCEPT override {
+    return SerializedSize(hidden_) + SerializedSize(head_number_) +
+           SerializedSize(head_size_) + SerializedSize(scale_) +
+           SerializedSize(with_fp16_);
+  }
+  void serialize(void* buffer) const TRT_NOEXCEPT override {
+    SerializeValue(&buffer, hidden_);
+    SerializeValue(&buffer, head_number_);
+    SerializeValue(&buffer, head_size_);
+    SerializeValue(&buffer, scale_);
+    SerializeValue(&buffer, with_fp16_);
+  }
+
+  nvinfer1::DimsExprs getOutputDimensions(
+      int output_index,
+      const nvinfer1::DimsExprs* inputs,
+      int nb_inputs,
+      nvinfer1::IExprBuilder& expr_builder)  // NOLINT
+      TRT_NOEXCEPT override;
+
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc* in_out,
+                                 int nb_inputs,
+                                 int nb_outputs) TRT_NOEXCEPT override;
+
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
+                       int nb_inputs,
+                       const nvinfer1::DynamicPluginTensorDesc* out,
+                       int nb_outputs) TRT_NOEXCEPT override {}
+
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                          int nb_inputs,
+                          const nvinfer1::PluginTensorDesc* outputs,
+                          int nb_outputs) const TRT_NOEXCEPT override {
+    return 0;
+  }
+
+  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+              const nvinfer1::PluginTensorDesc* outputDesc,
+              const void* const* inputs,
+              void* const* outputs,
+              void* workspace,
+              cudaStream_t stream) TRT_NOEXCEPT override;
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType* input_types,
+                                       int nb_inputs) const
+      TRT_NOEXCEPT override;
+
+  void destroy() TRT_NOEXCEPT override { delete this; }
+
+ private:
+  int hidden_;
+  int head_number_;
+  int head_size_;
+  float scale_;
+};
+
+class MultiheadMatmulRoformerPluginCreator : public nvinfer1::IPluginCreator {
+ public:
+  MultiheadMatmulRoformerPluginCreator() {}
+  const char* getPluginName() const TRT_NOEXCEPT override {
+    return "multihead_matmul_roformer_plugin";
+  }
+
+  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
+    return &field_collection_;
+  }
+
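+  // createPlugin deliberately returns nullptr: the plugin is constructed
+  // directly by the op converter rather than through the creator; only
+  // deserializePlugin below is needed to restore it from a serialized
+  // engine.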
+  nvinfer1::IPluginV2* createPlugin(const char* name,
+                                    const nvinfer1::PluginFieldCollection* fc)
+      TRT_NOEXCEPT override {
+    return nullptr;
+  }
+
+  nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                         const void* serial_data,
+                                         size_t serial_length)
+      TRT_NOEXCEPT override {
+    auto plugin =
+        new MultiheadMatmulRoformerPlugin(serial_data, serial_length);
+    return plugin;
+  }
+
+  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
+    plugin_namespace_ = lib_namespace;
+  }
+
+  const char* getPluginNamespace() const TRT_NOEXCEPT override {
+    return plugin_namespace_.c_str();
+  }
+
+ private:
+  std::string plugin_namespace_;
+  std::string plugin_name_;
+  nvinfer1::PluginFieldCollection field_collection_;
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+REGISTER_TRT_PLUGIN_V2(MultiheadMatmulRoformerPluginCreator);
+#endif
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index 27e40985d95f05d3ca317c564958fe26f2c5d53a..8cb8b7f4b7e2044a706ec80abc052d62afa7e8cb 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -21,6 +21,7 @@
 #include "glog/logging.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
 #include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
 #include "paddle/fluid/operators/math/bert_encoder_functor.h"
@@ -35,21 +36,6 @@ namespace plugin {
 
 // Dynamic Plugin below.
 #if IS_TRT_VERSION_GE(6000)
-template <typename T>
-__global__ void transpose(T *src,
-                          T *dst,
-                          const int batch_size,
-                          const int seq_len,
-                          const int head_num,
-                          const int size_per_head) {
-  int batch_id = blockIdx.x / (head_num * seq_len);
-  int seq_id = blockIdx.x % seq_len;
-  int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
-  dst[batch_id * (head_num * seq_len * size_per_head) +
-      seq_id * head_num * size_per_head + head_id * size_per_head +
-      threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
-}
-
 inline int round_up(int seq_len, int multiple = 32) {
   PADDLE_ENFORCE_GT(
       multiple,
@@ -115,133 +101,6 @@ __global__ void transpose_qkv_unpadding(const T *src,
                                         seq_id * size_per_head + threadIdx.x];
 }
 
-template <typename T>
-__global__ void TransposeQkvKernel(const int H, const T *input, T *output) {
-  // Input: BxSx3xNxH
-  // Bias: 3xSxB
-  // Output: 3xBxNxSxH
-  int n = threadIdx.y;
-  int s = blockIdx.x;
-  int b = blockIdx.y;
-  int m = blockIdx.z;
-
-  const int N = blockDim.y;
-  const int S = gridDim.x;
-  const int B = gridDim.y;
-
-  const int NH = N * H;
-  const int NHS = NH * S;
-  const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
-  const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
-
-  const int i = threadIdx.x;
-  output[out_offset + i] = input[in_offset + i];
-}
-
-inline void TransposeQKV(const int batch,
-                         const int seq_len,
-                         const int head_size,
-                         const int head_num,
-                         const float *input,
-                         float *output,
-                         cudaStream_t stream) {
-  int scratch_size = batch * head_num * seq_len * seq_len;
-  const dim3 grid(seq_len, batch, 3);
-  if (head_size % 4 == 0 && scratch_size % 4 == 0) {
-    const int h = head_size / 4;
-    const float4 *input4 = reinterpret_cast<const float4 *>(input);
-    float4 *output4 = reinterpret_cast<float4 *>(output);
-    const dim3 block(h, head_num, 1);
-    // limit h * head_num to max block size(1024).
-    PADDLE_ENFORCE_LE(h * head_num,
-                      1024,
-                      platform::errors::InvalidArgument(
-                          "head_num (%d) * head_size (%d) should <= %d",
-                          head_num,
-                          head_size,
-                          1024 * 4));
-    TransposeQkvKernel<float4><<<grid, block, 0, stream>>>(h, input4, output4);
-  } else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
-    const int h = head_size / 2;
-    const float2 *input2 = reinterpret_cast<const float2 *>(input);
-    float2 *output2 = reinterpret_cast<float2 *>(output);
-    const dim3 block(h, head_num, 1);
-    // limit h * head_num to max block size(1024).
-    PADDLE_ENFORCE_LE(h * head_num,
-                      1024,
-                      platform::errors::InvalidArgument(
-                          "head_num (%d) * head_size (%d) should <= %d",
-                          head_num,
-                          head_size,
-                          1024 * 2));
-    TransposeQkvKernel<float2><<<grid, block, 0, stream>>>(h, input2, output2);
-  } else {
-    const dim3 block(head_size, head_num, 1);
-    // limit head_size * head_num to max block size(1024).
-    PADDLE_ENFORCE_LE(head_size * head_num,
-                      1024,
-                      platform::errors::InvalidArgument(
-                          "head_num (%d) * head_size (%d) should <= %d",
-                          head_num,
-                          head_size,
-                          1024));
-    TransposeQkvKernel<float>
-        <<<grid, block, 0, stream>>>(head_size, input, output);
-  }
-}
-
-inline void TransposeQKV(const int batch,
-                         const int seq_len,
-                         const int head_size,
-                         const int head_num,
-                         const half *input,
-                         half *output,
-                         cudaStream_t stream) {
-  int scratch_size = batch * head_num * seq_len * seq_len;
-  const dim3 grid(seq_len, batch, 3);
-  if (head_size % 8 == 0 && scratch_size % 8 == 0) {
-    int h = head_size / 8;
-    const int4 *input4 = reinterpret_cast<const int4 *>(input);
-    int4 *output4 = reinterpret_cast<int4 *>(output);
-    dim3 block(h, head_num, 1);
-    // limit h * head_num to max block size(1024).
-    PADDLE_ENFORCE_LE(h * head_num,
-                      1024,
-                      platform::errors::InvalidArgument(
-                          "head_num (%d) * head_size (%d) should <= %d",
-                          head_num,
-                          head_size,
-                          1024 * 8));
-    TransposeQkvKernel<int4><<<grid, block, 0, stream>>>(h, input4, output4);
-  } else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
-    const int h = head_size / 2;
-    const half2 *input2 = reinterpret_cast<const half2 *>(input);
-    half2 *output2 = reinterpret_cast<half2 *>(output);
-    const dim3 block(h, head_num, 1);
-    // limit h * head_num to max block size(1024).
-    PADDLE_ENFORCE_LE(h * head_num,
-                      1024,
-                      platform::errors::InvalidArgument(
-                          "head_num (%d) * head_size (%d) should <= %d",
-                          head_num,
-                          head_size,
-                          1024 * 2));
-    TransposeQkvKernel<half2><<<grid, block, 0, stream>>>(h, input2, output2);
-  } else {
-    const dim3 block(head_size, head_num, 1);
-    // limit head_size * head_num to max block size(1024).
-    PADDLE_ENFORCE_LE(head_size * head_num,
-                      1024,
-                      platform::errors::InvalidArgument(
-                          "head_num (%d) * head_size (%d) should <= %d",
-                          head_num,
-                          head_size,
-                          1024));
-    TransposeQkvKernel<half>
-        <<<grid, block, 0, stream>>>(head_size, input, output);
-  }
-}
-
 int QkvToContextPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
 
 nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_multihead_matmul_roformer_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_multihead_matmul_roformer_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..58ab3bc1fb34d47acb205305cf0ae520b0f4a201
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_multihead_matmul_roformer_fuse_pass.py
@@ -0,0 +1,389 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
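+
+# This test builds the un-fused RoFormer attention subgraph and checks that
+# multihead_matmul_roformer_fuse_pass collapses it into a single
+# multihead_matmul_roformer op plus the trailing matmul. The rotary part of
+# the graph computes out = x * cos + swap_halves(x) * sin, where swap_halves
+# exchanges the two halves of the last axis via split + concat; any sign
+# handling is assumed to be baked into the sin tensor. A minimal NumPy
+# sketch of that reference computation (illustrative only; these helpers are
+# not part of the test):
+#
+#     def swap_halves(x):
+#         left, right = np.split(x, 2, axis=-1)
+#         return np.concatenate([right, left], axis=-1)
+#
+#     def rotary(x, cos, sin):
+#         return x * cos + swap_halves(x) * sin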
+
+from auto_scan_test import PassAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import paddle.inference as paddle_infer
+import numpy as np
+from functools import partial
+import unittest
+
+
+class TestMultiheadMatmulRoformerFusePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        # trt
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=8,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Float32,
+            use_static=False,
+            use_calib_mode=False,
+        )
+        config.set_trt_dynamic_shape_info(
+            {
+                "mul_x": [1, 1, 768],
+                "eltadd_qk_b_var": [1, 12, 1, 1],
+                "cos_input": [1, 12, 1, 64],
+                "sin_input": [1, 12, 1, 64],
+            },
+            {
+                "mul_x": [1, 128, 768],
+                "eltadd_qk_b_var": [1, 12, 128, 128],
+                "cos_input": [1, 12, 128, 64],
+                "sin_input": [1, 12, 128, 64],
+            },
+            {
+                "mul_x": [1, 128, 768],
+                "eltadd_qk_b_var": [1, 12, 128, 128],
+                "cos_input": [1, 12, 128, 64],
+                "sin_input": [1, 12, 128, 64],
+            },
+        )
+        yield config, ["multihead_matmul_roformer", "matmul"], (1e-2, 1e-3)
+
+    def sample_program_config(self, draw):
+        def generate_mul_input():
+            return (
+                np.random.random([1, 128, 768]).astype(np.float32) - 0.5
+            ) / 100.0
+
+        def generate_elewise_input():
+            return (
+                np.random.random([1, 12, 128, 128]).astype(np.float32)
+            ) / 100.0
+
+        def generate_cos_input():
+            return np.random.random([1, 12, 128, 64]).astype(np.float32) - 0.5
+
+        def generate_sin_input():
+            return np.random.random([1, 12, 128, 64]).astype(np.float32) - 0.5
+
+        def generate_weight1():
+            return (
+                np.random.random((768, 768)).astype(np.float32) - 0.5
+            ) / 100.0
+
+        def generate_weight2():
+            return (np.random.random(768).astype(np.float32) - 0.5) / 100.0
+
+        mul_0 = OpConfig(
+            "matmul",
+            inputs={"X": ["mul_x"], "Y": ["mul_0_w"]},
+            outputs={"Out": ["mul_0_out"]},
+            alpha=1.0,
+            transpose_X=False,
+            transpose_Y=False,
+        )
+        mul_1 = OpConfig(
+            "matmul",
+            inputs={"X": ["mul_x"], "Y": ["mul_1_w"]},
+            outputs={"Out": ["mul_1_out"]},
+            alpha=1.0,
+            transpose_X=False,
+            transpose_Y=False,
+        )
+        mul_2 = OpConfig(
+            "matmul",
+            inputs={"X": ["mul_x"], "Y": ["mul_2_w"]},
+            outputs={"Out": ["mul_2_out"]},
+            alpha=1.0,
+            transpose_X=False,
+            transpose_Y=False,
+        )
+        ele_0 = OpConfig(
+            "elementwise_add",
+            inputs={"X": [mul_0.outputs["Out"][0]], "Y": ["ele_0_w"]},
+            outputs={"Out": ["ele_0_out"]},
+            axis=-1,
+        )
+        ele_1 = OpConfig(
+            "elementwise_add",
+            inputs={"X": [mul_1.outputs["Out"][0]], "Y": ["ele_1_w"]},
+            outputs={"Out": ["ele_1_out"]},
+            axis=-1,
+        )
+        ele_2 = OpConfig(
+            "elementwise_add",
+            inputs={"X": [mul_2.outputs["Out"][0]], "Y": ["ele_2_w"]},
+            outputs={"Out": ["ele_2_out"]},
+            axis=-1,
+        )
+        reshape_0 = OpConfig(
+            "reshape2",
+            inputs={"X": [ele_0.outputs["Out"][0]]},
+            outputs={"Out": ["reshape_0_out"], "XShape": ["reshape_0_Xout"]},
+            shape=(1, 128, 12, 64),
+        )
+        reshape_1 = OpConfig(
+            "reshape2",
+            inputs={"X": [ele_1.outputs["Out"][0]]},
+            outputs={"Out": ["reshape_1_out"], "XShape": ["reshape_1_Xout"]},
+            shape=(1, 128, 12, 64),
+        )
+        reshape_2 = OpConfig(
+            "reshape2",
+            inputs={"X": [ele_2.outputs["Out"][0]]},
+            outputs={"Out": ["reshape_2_out"], "XShape": ["reshape_2_Xout"]},
+            shape=(1, 128, 12, 64),
+        )
+        transpose_0 = OpConfig(
+            "transpose2",
+            inputs={"X": [reshape_0.outputs["Out"][0]]},
+            outputs={"Out": ["transpose_0_out"]},
+            axis=(0, 2, 1, 3),
+        )
+        transpose_1 = OpConfig(
+            "transpose2",
+            inputs={"X": [reshape_1.outputs["Out"][0]]},
["transpose_1_out"]}, + axis=(0, 2, 1, 3), + ) + transpose_2 = OpConfig( + "transpose2", + inputs={"X": [reshape_2.outputs["Out"][0]]}, + outputs={"Out": ["transpose_2_out"]}, + axis=(0, 2, 1, 3), + ) + + # roformer part + # q with scale branch + ele_mul_q_0 = OpConfig( + "elementwise_mul", # without split && concat + inputs={"X": [transpose_0.outputs["Out"][0]], "Y": ["cos_input"]}, + outputs={"Out": ["ele_mul_q_0_out"]}, + axis=-1, + ) + + split_q_0 = OpConfig( + "split", + inputs={"X": [transpose_0.outputs["Out"][0]]}, + outputs={"Out": ["split_q_0_out_0", "split_q_0_out_1"]}, + axis=3, + num=2, + ) + + concat_q_0 = OpConfig( + "concat", + inputs={ + "X": [split_q_0.outputs["Out"][1], split_q_0.outputs["Out"][0]] + }, + outputs={"Out": ["concat_q_0_out"]}, + axis=-1, + ) + + ele_mul_q_1 = OpConfig( + "elementwise_mul", # without split && concat + inputs={"X": [concat_q_0.outputs["Out"][0]], "Y": ["sin_input"]}, + outputs={"Out": ["ele_mul_q_1_out"]}, + axis=-1, + ) + + ele_add_q_0 = OpConfig( + "elementwise_add", + inputs={ + "X": [ele_mul_q_0.outputs["Out"][0]], + "Y": [ele_mul_q_1.outputs["Out"][0]], + }, + outputs={"Out": ["ele_add_q_0_out"]}, + axis=-1, + ) + + scale_0 = OpConfig( + "scale", + inputs={"X": [ele_add_q_0.outputs["Out"][0]]}, + outputs={"Out": ["scale_0_out"]}, + scale=0.1961161345243454, + bias=0, + ) + + # k branch which without scale op + ele_mul_k_0 = OpConfig( + "elementwise_mul", # without split && concat + inputs={"X": [transpose_1.outputs["Out"][0]], "Y": ["cos_input"]}, + outputs={"Out": ["ele_mul_k_0_out"]}, + axis=-1, + ) + + split_k_0 = OpConfig( + "split", + inputs={"X": [transpose_1.outputs["Out"][0]]}, + outputs={"Out": ["split_k_0_out_0", "split_k_0_out_1"]}, + axis=3, + num=2, + ) + + concat_k_0 = OpConfig( + "concat", + inputs={ + "X": [split_k_0.outputs["Out"][1], split_k_0.outputs["Out"][0]] + }, + outputs={"Out": ["concat_k_0_out"]}, + axis=-1, + ) + + ele_mul_k_1 = OpConfig( + "elementwise_mul", # with split && concat + inputs={"X": [concat_k_0.outputs["Out"][0]], "Y": ["sin_input"]}, + outputs={"Out": ["ele_mul_k_1_out"]}, + axis=-1, + ) + + ele_add_k_0 = OpConfig( + "elementwise_add", + inputs={ + "X": [ele_mul_k_0.outputs["Out"][0]], + "Y": [ele_mul_k_1.outputs["Out"][0]], + }, + outputs={"Out": ["ele_add_k_0_out"]}, + axis=-1, + ) + + matmul_0 = OpConfig( + "matmul", + inputs={ + "X": [scale_0.outputs["Out"][0]], + "Y": [ele_add_k_0.outputs["Out"][0]], + }, + outputs={"Out": ["matmul_0_out"]}, + alpha=1.0, + transpose_X=False, + transpose_Y=True, + fused_reshape_Out=[], + fused_reshape_X=[], + fused_reshape_Y=[], + fused_transpose_Out=[], + fused_transpose_X=[], + fused_transpose_Y=[], + ) + ele_3 = OpConfig( + "elementwise_add", + inputs={ + "X": [matmul_0.outputs["Out"][0]], + "Y": ["eltadd_qk_b_var"], + }, + outputs={"Out": ["ele_3_out"]}, + axis=-1, + ) + softmax_op = OpConfig( + "softmax", + inputs={"X": [ele_3.outputs["Out"][0]]}, + outputs={"Out": ["softmax_out"]}, + axis=3, + is_test=True, + ) + matmul_1 = OpConfig( + "matmul", + inputs={ + "X": [softmax_op.outputs["Out"][0]], + "Y": [transpose_2.outputs["Out"][0]], + }, + outputs={"Out": ["matmul_1_out"]}, + alpha=1.0, + transpose_X=False, + transpose_Y=False, + ) + transpose_3 = OpConfig( + "transpose2", + inputs={"X": [matmul_1.outputs["Out"][0]]}, + outputs={"Out": ["transpose_3_out"]}, + axis=(0, 2, 1, 3), + ) + reshape_3 = OpConfig( + "reshape2", + inputs={"X": [transpose_3.outputs["Out"][0]]}, + outputs={"Out": ["reshape_3_out"], "XShape": ["reshape_3_Xout"]}, + shape=(1, 128, 
+            shape=(1, 128, 768),
+        )
+        mul_3 = OpConfig(
+            "matmul",
+            inputs={"X": [reshape_3.outputs["Out"][0]], "Y": ["mul_3_w"]},
+            outputs={"Out": ["mul_3_out"]},
+            alpha=1.0,
+            transpose_X=False,
+            transpose_Y=False,
+            fused_reshape_Out=[],
+            fused_reshape_X=[],
+            fused_reshape_Y=[],
+            fused_transpose_Out=[],
+            fused_transpose_X=[],
+            fused_transpose_Y=[],
+        )
+        ops = [
+            mul_0,
+            mul_1,
+            mul_2,
+            ele_0,
+            ele_1,
+            ele_2,
+            reshape_0,
+            reshape_1,
+            reshape_2,
+            transpose_0,
+            transpose_1,
+            transpose_2,
+            ele_mul_q_0,
+            split_q_0,
+            concat_q_0,
+            ele_mul_q_1,
+            ele_add_q_0,
+            ele_mul_k_0,
+            split_k_0,
+            concat_k_0,
+            ele_mul_k_1,
+            ele_add_k_0,
+            scale_0,
+            matmul_0,
+            ele_3,
+            softmax_op,
+            matmul_1,
+            transpose_3,
+            reshape_3,
+            mul_3,
+        ]
+        program_config = ProgramConfig(
+            ops=ops,
+            inputs={
+                "mul_x": TensorConfig(data_gen=partial(generate_mul_input)),
+                "eltadd_qk_b_var": TensorConfig(
+                    data_gen=partial(generate_elewise_input)
+                ),
+                "cos_input": TensorConfig(data_gen=partial(generate_cos_input)),
+                "sin_input": TensorConfig(data_gen=partial(generate_sin_input)),
+            },
+            weights={
+                "mul_0_w": TensorConfig(data_gen=partial(generate_weight1)),
+                "mul_1_w": TensorConfig(data_gen=partial(generate_weight1)),
+                "mul_2_w": TensorConfig(data_gen=partial(generate_weight1)),
+                "mul_3_w": TensorConfig(data_gen=partial(generate_weight1)),
+                "ele_0_w": TensorConfig(data_gen=partial(generate_weight2)),
+                "ele_1_w": TensorConfig(data_gen=partial(generate_weight2)),
+                "ele_2_w": TensorConfig(data_gen=partial(generate_weight2)),
+            },
+            outputs=[ops[-1].outputs["Out"][0]],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=100,
+            min_success_num=1,
+            passes=["multihead_matmul_roformer_fuse_pass"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul_roformer.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul_roformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..25bdfb71c7ff03cfa8fb3e8fe227dd0bac2c4db5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul_roformer.py
@@ -0,0 +1,563 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
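+
+# This test feeds TensorRT the same un-fused RoFormer attention graph as the
+# fuse-pass test above and checks the converted engine end to end. The
+# (1, 5) tuples yielded in sample_predictor_configs are read, in this test
+# framework, as (number of TRT engines, number of remaining paddle ops);
+# the relaxed (1e-3, 1e-3) tolerances cover both the fp32 and fp16 runs.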
+
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+from program_config import TensorConfig, ProgramConfig
+import unittest
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import List
+
+
+class TrtConvertMultiHeadMatmulRoformerTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        def generate_input1(batch, dim1):
+            return (
+                np.random.random((batch, dim1, 768)).astype(np.float32) - 0.5
+            ) / 100.0
+
+        def generate_input2(shape):
+            return (np.random.random(shape).astype(np.float32) - 0.5) / 100.0
+
+        def generate_cos_input(batch, dim1):
+            return (
+                np.random.random((batch, 12, dim1, 64)).astype(np.float32)
+                - 0.5
+            )
+
+        def generate_sin_input(batch, dim1):
+            return (
+                np.random.random((batch, 12, dim1, 64)).astype(np.float32)
+                - 0.5
+            )
+
+        def generate_weight1():
+            return (
+                np.random.random((768, 768)).astype(np.float32) - 0.5
+            ) / 100.0
+
+        def generate_weight2():
+            return (np.random.random(768).astype(np.float32) - 0.5) / 100.0
+
+        for batch in [1, 2, 4]:
+            self.batch = batch
+            for reshape_shape in [[0, 0, 12, 64]]:
+                for dim1 in [128]:
+                    input2_shapes = [
+                        (batch, reshape_shape[2], dim1, dim1)
+                    ]  # (batch, 12, dim1, dim1)
+                    # an alternative broadcastable bias shape would be
+                    # [batch, 1, 1, dim1]
+                    for input2_shape in input2_shapes:
+                        for axis in [0]:
+                            dics = [
+                                {
+                                    "alpha": 1.0,
+                                    "transpose_X": False,
+                                    "transpose_Y": False,
+                                },
+                                {"axis": 2},
+                                {"shape": reshape_shape},
+                                {"axis": [0, 2, 1, 3]},
+                                {
+                                    "alpha": 1.0,
+                                    "transpose_X": False,
+                                    "transpose_Y": False,
+                                },
+                                {"axis": 2},
+                                {"shape": reshape_shape},
+                                {"axis": [0, 2, 1, 3]},
+                                {
+                                    "alpha": 1.0,
+                                    "transpose_X": False,
+                                    "transpose_Y": False,
+                                },
+                                {"axis": 2},
+                                {"shape": reshape_shape},
+                                {"axis": [0, 2, 1, 3]},
+                                {
+                                    "scale": 0.125,
+                                    "bias": 0.0,
+                                    "bias_after_scale": True,
+                                },
+                                {
+                                    "alpha": 1.0,
+                                    "transpose_X": False,
+                                    "transpose_Y": True,
+                                    "fused_reshape_X": [],
+                                    "fused_reshape_Y": [],
+                                    "fused_transpose_X": [],
+                                    "fused_transpose_Y": [],
+                                    "fused_reshape_Out": [],
+                                    "fused_transpose_Out": [],
+                                },
+                                {"axis": axis},
+                                {"axis": -1, "is_test": True},
+                                {
+                                    "seed": 0,
+                                    "dropout_prob": 0.10000000149011612,
+                                    "dropout_implementation": "upscale_in_train",
+                                    "fix_seed": False,
+                                    "is_test": True,
+                                },
+                                {
+                                    "alpha": 1.0,
+                                    "transpose_X": False,
+                                    "transpose_Y": False,
+                                    "fused_reshape_X": [],
+                                    "fused_reshape_Y": [],
+                                    "fused_transpose_X": [],
+                                    "fused_transpose_Y": [],
+                                    "fused_reshape_Out": [],
+                                    "fused_transpose_Out": [],
+                                },
+                                {"axis": [0, 2, 1, 3]},
+                                {"shape": [0, 0, 768]},
+                                {
+                                    "alpha": 1.0,
+                                    "transpose_X": False,
+                                    "transpose_Y": False,
+                                },
+                            ]
+
+                            ops_config = [
+                                {
+                                    "op_type": "matmul",
+                                    "op_inputs": {
+                                        "X": ["input_data1"],
+                                        "Y": ["mul1_weight"],
+                                    },
+                                    "op_outputs": {"Out": ["mul1_output"]},
+                                    "op_attrs": dics[0],
+                                },
+                                {
+                                    "op_type": "elementwise_add",
+                                    "op_inputs": {
+                                        "X": ["mul1_output"],
+                                        "Y": ["elementwise_add1_weight"],
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["elementwise_add1_output"]
+                                    },
+                                    "op_attrs": dics[1],
+                                },
+                                {
+                                    "op_type": "reshape2",
+                                    "op_inputs": {
+                                        "X": ["elementwise_add1_output"],
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["reshape21_output"],
+                                        "XShape": ["reshape21_output_xshape"],
+                                    },
+                                    "op_attrs": dics[2],
+                                },
+                                {
+                                    "op_type": "transpose2",
+                                    "op_inputs": {"X": ["reshape21_output"]},
+                                    "op_outputs": {
+                                        "Out": ["transpose21_output"],
+                                        "XShape": ["transpose21_output_xshape"],
+                                    },
+                                    "op_attrs": dics[3],
+                                },
+                                # roformer part
+                                # q with scale branch
+                                {
"elementwise_mul", + "op_inputs": { + "X": ["transpose21_output"], + "Y": ["cos_input"], + }, + "op_outputs": { + "Out": ["elementwise_mul_q_0_output"] + }, + "op_attrs": {"axis": -1}, + }, + { + "op_type": "split", + "op_inputs": { + "X": ["transpose21_output"], + }, + "op_outputs": { + "Out": [ + "split_q_0_output_0", + "split_q_0_output_1", + ], + }, + "op_attrs": { + "axis": 3, + "num": 2, + }, + }, + { + "op_type": "concat", + "op_inputs": { + "X": [ + "split_q_0_output_1", + "split_q_0_output_0", + ], + }, + "op_outputs": { + "Out": ["concat_q_0_output"] + }, + "op_attrs": {"axis": -1}, + }, + { + "op_type": "elementwise_mul", + "op_inputs": { + "X": ["concat_q_0_output"], + "Y": ["sin_input"], + }, + "op_outputs": { + "Out": ["elementwise_mul_q_1_output"] + }, + "op_attrs": {"axis": -1}, + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["elementwise_mul_q_0_output"], + "Y": ["elementwise_mul_q_1_output"], + }, + "op_outputs": { + "Out": ["elementwise_add_q_0_output"] + }, + "op_attrs": {"axis": -1}, + }, + { + "op_type": "scale", + "op_inputs": { + "X": ["elementwise_add_q_0_output"], + }, + "op_outputs": {"Out": ["scale_output"]}, + "op_attrs": dics[12], + }, + # k branch + { + "op_type": "matmul", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul2_weight"], + }, + "op_outputs": {"Out": ["mul2_output"]}, + "op_attrs": dics[4], + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["mul2_output"], + "Y": ["elementwise_add2_weight"], + }, + "op_outputs": { + "Out": ["elementwise_add2_output"] + }, + "op_attrs": dics[5], + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["elementwise_add2_output"] + }, + "op_outputs": { + "Out": ["reshape22_output"], + "XShape": ["reshape22_output_xshape"], + }, + "op_attrs": dics[6], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape22_output"]}, + "op_outputs": { + "Out": ["transpose22_output"], + "XShape": ["transpose22_output_xshape"], + }, + "op_attrs": dics[7], + }, + # roformer part + # k without scale branch + { + "op_type": "elementwise_mul", + "op_inputs": { + "X": ["transpose22_output"], + "Y": ["cos_input"], + }, + "op_outputs": { + "Out": ["elementwise_mul_k_0_output"] + }, + "op_attrs": {"axis": -1}, + }, + { + "op_type": "split", + "op_inputs": { + "X": ["transpose22_output"], + }, + "op_outputs": { + "Out": [ + "split_k_0_output_0", + "split_k_0_output_1", + ] + }, + "op_attrs": {"axis": 3, "num": 2}, + }, + { + "op_type": "concat", + "op_inputs": { + "X": [ + "split_k_0_output_1", + "split_k_0_output_0", + ] + }, + "op_outputs": { + "Out": ["concat_k_0_output"] + }, + "op_attrs": { + "axis": -1, + }, + }, + { + "op_type": "elementwise_mul", + "op_inputs": { + "X": ["concat_k_0_output"], + "Y": ["sin_input"], + }, + "op_outputs": { + "Out": ["elementwise_mul_k_1_output"] + }, + "op_attrs": {"axis": -1}, + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["elementwise_mul_k_0_output"], + "Y": ["elementwise_mul_k_1_output"], + }, + "op_outputs": { + "Out": ["elementwise_add_k_0_output"] + }, + "op_attrs": {"axis": -1}, + }, + # v branch + { + "op_type": "matmul", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul3_weight"], + }, + "op_outputs": {"Out": ["mul3_output"]}, + "op_attrs": dics[8], + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["mul3_output"], + "Y": ["elementwise_add3_weight"], + }, + "op_outputs": { + "Out": ["elementwise_add3_output"] + }, + "op_attrs": dics[9], + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": 
["elementwise_add3_output"] + }, + "op_outputs": { + "Out": ["reshape23_output"], + "XShape": ["reshape23_output_xshape"], + }, + "op_attrs": dics[10], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape23_output"]}, + "op_outputs": { + "Out": ["transpose23_output"], + "XShape": ["transpose23_output_xshape"], + }, + "op_attrs": dics[11], + }, + { + "op_type": "matmul", + "op_inputs": { + "X": ["scale_output"], + "Y": ["elementwise_add_k_0_output"], + }, + "op_outputs": {"Out": ["matmul1_output"]}, + "op_attrs": dics[13], + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["matmul1_output"], + "Y": ["input_data2"], + }, + "op_outputs": { + "Out": ["elementwise_add4_output"] + }, + "op_attrs": {"axis": -1}, + }, + { + "op_type": "softmax", + "op_inputs": { + "X": ["elementwise_add4_output"] + }, + "op_outputs": {"Out": ["softmax_output"]}, + "op_attrs": dics[15], + }, + { + "op_type": "matmul", + "op_inputs": { + "X": ["softmax_output"], + "Y": ["transpose23_output"], + }, + "op_outputs": {"Out": ["matmul2_output"]}, + "op_attrs": dics[17], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["matmul2_output"]}, + "op_outputs": { + "Out": ["transpose24_output"], + "XShape": ["transpose24_output_xshape"], + }, + "op_attrs": dics[18], + }, + { + "op_type": "reshape2", + "op_inputs": {"X": ["transpose24_output"]}, + "op_outputs": { + "Out": ["reshape24_output"], + "XShape": ["reshape24_output_xshape"], + }, + "op_attrs": dics[19], + }, + # In order to fuse ops with + # multihead_matmul_fuse_pass_v2, the last op + # must be mul. + { + "op_type": "matmul", + "op_inputs": { + "X": ["reshape24_output"], + "Y": ["mul4_weight"], + }, + "op_outputs": {"Out": ["mul4_output"]}, + "op_attrs": dics[20], + }, + ] + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={ + "mul1_weight": TensorConfig( + data_gen=partial(generate_weight1) + ), + "mul2_weight": TensorConfig( + data_gen=partial(generate_weight1) + ), + "mul3_weight": TensorConfig( + data_gen=partial(generate_weight1) + ), + "mul4_weight": TensorConfig( + data_gen=partial(generate_weight1) + ), + "elementwise_add1_weight": TensorConfig( + data_gen=partial(generate_weight2) + ), + "elementwise_add2_weight": TensorConfig( + data_gen=partial(generate_weight2) + ), + "elementwise_add3_weight": TensorConfig( + data_gen=partial(generate_weight2) + ), + }, + inputs={ + "input_data1": TensorConfig( + data_gen=partial( + generate_input1, batch, dim1 + ) + ), + "input_data2": TensorConfig( + data_gen=partial( + generate_input2, input2_shape + ) + ), + "cos_input": TensorConfig( + data_gen=partial( + generate_cos_input, batch, dim1 + ) + ), + "sin_input": TensorConfig( + data_gen=partial( + generate_sin_input, batch, dim1 + ) + ), + }, + outputs=["mul4_output"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + # The last dim of input1 and input2 should be static. 
+            self.dynamic_shape.min_input_shape = {
+                "input_data1": [1, 1, 768],
+                "input_data2": [1, 12, 1, 1],
+                "cos_input": [1, 12, 1, 64],
+                "sin_input": [1, 12, 1, 64],
+                "reshape24_output": [1, 1, 768],
+            }
+            self.dynamic_shape.max_input_shape = {
+                "input_data1": [10, 128, 768],
+                "input_data2": [10, 12, 128, 128],
+                "cos_input": [10, 12, 128, 64],
+                "sin_input": [10, 12, 128, 64],
+                "reshape24_output": [10, 128, 768],
+            }
+            self.dynamic_shape.opt_input_shape = {
+                "input_data1": [8, 128, 768],
+                "input_data2": [8, 12, 128, 128],
+                "cos_input": [8, 12, 128, 64],
+                "sin_input": [8, 12, 128, 64],
+                "reshape24_output": [8, 128, 768],
+            }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        self.trt_param.workspace_size = 2013265920
+        yield self.create_inference_config(), (1, 5), (1e-3, 1e-3)
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), (1, 5), (1e-3, 1e-3)
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()