From a48b8e2cc124e57a851377410c2adf05861e4064 Mon Sep 17 00:00:00 2001
From: Wang Bojun <105858416+wwbitejotunn@users.noreply.github.com>
Date: Fri, 13 Jan 2023 18:55:27 +0800
Subject: [PATCH] add oss flash fmha and fmhca support (#49438)

* add fmha_flashattention oss plugin
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   2 +
 .../trt_cross_multihead_matmul_fuse_pass.cc   | 587 ++++++++++++++++++
 .../ir/trt_cross_multihead_matmul_fuse_pass.h |  92 +++
 .../trt_flash_multihead_matmul_fuse_pass.cc   | 579 +++++++++++++++++
 .../ir/trt_flash_multihead_matmul_fuse_pass.h |  91 +++
 .../ir/trt_multihead_matmul_fuse_pass.cc      |   3 +
 .../fluid/inference/api/analysis_predictor.cc |   4 +
 .../inference/api/paddle_pass_builder.cc      |   3 +
 .../inference/tensorrt/convert/CMakeLists.txt |   2 +
 .../convert/cross_multihead_matmul_op.cc      | 277 +++++++
 .../convert/flash_multihead_matmul_op.cc      | 190 ++++++
 paddle/fluid/inference/tensorrt/op_teller.cc  |   6 +
 ...test_trt_convert_cross_multihead_matmul.py | 326 ++++++++++
 ...test_trt_convert_flash_multihead_matmul.py | 321 ++++++++++
 14 files changed, 2483 insertions(+)
 create mode 100644 paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.h
 create mode 100644 paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.h
 create mode 100644 paddle/fluid/inference/tensorrt/convert/cross_multihead_matmul_op.cc
 create mode 100644 paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cross_multihead_matmul.py
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flash_multihead_matmul.py

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 1a84e815e0d..476881f0725 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -130,6 +130,8 @@ target_link_libraries(generate_pass pass_desc_proto)
 if(WITH_TENSORRT)
   pass_library(trt_map_matmul_to_mul_pass inference)
   pass_library(trt_multihead_matmul_fuse_pass inference)
+  pass_library(trt_flash_multihead_matmul_fuse_pass inference)
+  pass_library(trt_cross_multihead_matmul_fuse_pass inference)
   pass_library(trt_skip_layernorm_fuse_pass inference)
   pass_library(merge_layernorm_fuse_pass inference)
   pass_library(preln_skip_layernorm_fuse_pass inference)
diff --git a/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.cc
new file mode 100644
index 00000000000..ee0075f3674
--- /dev/null
+++ b/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.cc
@@ -0,0 +1,587 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
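+//
+// At a glance (informal sketch of what this pass fuses): the matched
+// subgraph below is standard scaled dot-product cross attention, per head
+//   Q = input_q * W_q,  K = input_kv * W_k,  V = input_kv * W_v
+//   Out = softmax(Q * K^T / sqrt(head_size)) * V
+// and it is rewritten into a single cross_multihead_matmul op, which the
+// TensorRT converter added in this patch lowers to the fMHCA (fused
+// multi-head cross attention) OSS plugin (see cross_multihead_matmul_op.cc).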
+ +#include "paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.h" + +#include +#include "math.h" // NOLINT + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" +#ifdef PADDLE_WITH_TENSORRT +#include "paddle/fluid/inference/tensorrt/helper.h" +#endif +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +// input_q input_kv +// |q |k v +// | |-------| +// matmul matmul matmul +// | | | +// reshape reshape reshape +// | | | +// trans trans trans +// |(x) |(y) | +// matmul | +// | | +// scale | +// | | +// softmax |(y) +// |------matmul +// (x) | +// trans +// | +// reshape +// | +// output +// +// -> fused to +// +// input +// | +// cross_multihead_matmul +// | +// output + +PDNode* TrtCrossMultiHeadMatmulPattern::operator()() { + std::unordered_set mul_ops{"mul", "matmul_v2"}; + std::unordered_set matmul_ops{"matmul", "matmul_v2"}; + auto* input0 = pattern->NewNode(input0_repr()); + auto* input1 = pattern->NewNode(input1_repr()); + + input0->assert_is_ops_input(mul_ops); + input1->assert_is_ops_input(mul_ops); + VLOG(5) << "Start match TrtCrossMultiHeadMatmulPattern"; + // First path + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops); + auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul0_out_var = + pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops); + + mul0_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + + auto* reshape2_0_out_var = + pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2"); + reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops, "X"); + + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale"); + + auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + auto* scale_out_var = + pattern->NewNode(scale_out_repr())->assert_is_op_output("scale"); + scale_out_var->AsIntermediate()->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = + pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); + softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); + + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + 
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2"); + + // Second path to matmul + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops); + auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul1_out_var = + pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops); + + mul1_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + + auto* reshape2_1_out_var = + pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2"); + reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_1_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops, "Y"); // link to matmul qk + + // Third path to matmul + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops); + auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul2_out_var = + pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops); + + mul2_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + + auto* reshape2_2_out_var = + pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2"); + reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_2_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops); // link to matmul qkv + + // Q path + mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); + reshape2_0->LinksFrom({mul0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + // K path + mul1->LinksFrom({input1, mul1_w_var}).LinksTo({mul1_out_var}); + + reshape2_1->LinksFrom({mul1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + // compute q*k + matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) + .LinksTo({matmul_qk_out_var}); + scale->LinksFrom({matmul_qk_out_var}).LinksTo({scale_out_var}); + softmax_qk->LinksFrom({scale_out_var}).LinksTo({softmax_qk_out_var}); + // V path + mul2->LinksFrom({input1, mul2_w_var}).LinksTo({mul2_out_var}); + + reshape2_2->LinksFrom({mul2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + // compute q*k*v + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + + return reshape2_qkv_out_var; +} + +} // namespace patterns + +TrtCrossMultiHeadMatmulFusePass::TrtCrossMultiHeadMatmulFusePass() { + AddOpCompat(OpCompat("mul")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + 
.AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(2) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + // -->: (B, S, H, N) -> (B, H, S, N) + // <--: (B, H, S, N) -> (B, S, H, N) + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); + + // QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S) + // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N) + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsType() // QK(anyvalue, will copy to new op) QKV(1.0) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") // QK(true) QKV(false) + .IsType() + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") // QK(true) QKV(false) + .IsType() + .End(); + + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); +} + +int TrtCrossMultiHeadMatmulFusePass::BuildCrossFusion( + Graph* graph, const std::string& name_scope, Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. 
+ patterns::TrtCrossMultiHeadMatmulPattern multihead_pattern(pattern, + name_scope); + + multihead_pattern(); + auto fuse_creater = [&](Node* input0, + Node* input1, + Node* mul0, + Node* mul1, + Node* mul2, + Node* mul0_out, + Node* mul1_out, + Node* mul2_out, + Node* mul0_w, + Node* mul1_w, + Node* mul2_w, + Node* reshape2, + Node* reshape2_qkv_out, + Node* scale, + Node* scale_out) { + // get Device context + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + + auto scale_attr = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); + + // create multihead + OpDesc multihead_op_desc(mul0->Op()->Block()); + auto reshape_desc = reshape2->Op(); + int head_number = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(2); + multihead_op_desc.SetType("cross_multihead_matmul"); + multihead_op_desc.SetInput("Input_q", {input0->Name()}); + multihead_op_desc.SetInput("Input_kv", {input1->Name()}); + + auto* wq_tensor = + scope->FindVar(mul0_w->Name())->GetMutable(); + auto* wk_tensor = + scope->FindVar(mul1_w->Name())->GetMutable(); + auto* wv_tensor = + scope->FindVar(mul2_w->Name())->GetMutable(); + + int hidden_out = wq_tensor->dims()[1]; + int head_size = hidden_out / head_number; + if (abs(scale_attr - 1.0f / sqrt(static_cast(head_size))) > 1e-5) { + VLOG(3) << "scale of muilthead matmul do not fit the requirement of " + "flash attention plugin, Stop fusing."; + return; + } + VLOG(5) << "trt cross attention get wq_tensor name = " << mul0_w->Name() + << "trt cross attention wk_tensor name = " << mul1_w->Name() + << "trt cross attention wv_tensor name = " << mul2_w->Name(); + + auto* wk_data = wk_tensor->data(); + auto* wv_data = wv_tensor->data(); + // combined_w_dims = [in,2,out] + auto combined_w_kv_dims = + phi::make_ddim({wk_tensor->dims()[0], 2, wk_tensor->dims()[1]}); + VLOG(5) << "trt cross attention trt wk_dim in:" << wk_tensor->dims()[0] + << "trt cross attention trt wk_dim out:" << wk_tensor->dims()[1]; + auto* combined_w_kv_desc = mul1_w->Var(); + combined_w_kv_desc->SetShape( + {wk_tensor->dims()[0], 2, wk_tensor->dims()[1]}); + combined_w_kv_desc->SetPersistable(true); + phi::DenseTensor tmp_combined_w_kv_tensor; + tmp_combined_w_kv_tensor.Resize(combined_w_kv_dims); + float* tmp_combined_w_kv_data = + dev_ctx->template HostAlloc(&tmp_combined_w_kv_tensor); + + std::vector w_vec = {wk_data, wv_data}; + int dims_h = combined_w_kv_dims[0], dims_w = combined_w_kv_dims[2]; + // dims_h=in_feature, dims_w=out_feature + // Combine the three fc weights together. 
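+      // Here only the K and V weights (two of them) are interleaved; W_q is
+      // kept as a separate input of the fused op. The combined tensor has
+      // shape [hidden_in, 2, hidden_out]: slice j = 0 holds W_k and slice
+      // j = 1 holds W_v, i.e. element wk[i][k] lands at offset
+      // i * (2 * dims_w) + k and wv[i][k] at i * (2 * dims_w) + dims_w + k
+      // in the loop below.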
+ // weight [Hidden_in * 3 * N * H] + for (int i = 0; i < dims_h; i++) { + for (int j = 0; j < 2; j++) { + for (int k = 0; k < dims_w; k++) { + int out_index = i * (2 * dims_w) + j * dims_w + k; + int in_index = i * dims_w + k; + tmp_combined_w_kv_data[out_index] = w_vec[j][in_index]; + } + } + } + wk_tensor->clear(); + wk_tensor->Resize(combined_w_kv_dims); + auto* new_combined_w_kv_data = dev_ctx->template HostAlloc( + wk_tensor, sizeof(float) * wk_tensor->numel()); + memcpy(new_combined_w_kv_data, + tmp_combined_w_kv_data, + sizeof(float) * wk_tensor->numel()); + + scope->EraseVars({mul2_w->Name()}); + + multihead_op_desc.SetInput("W_q", {mul0_w->Name()}); + multihead_op_desc.SetInput("W_kv", {mul1_w->Name()}); + multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); + multihead_op_desc.SetAttr("alpha", scale_attr); + multihead_op_desc.SetAttr("head_number", head_number); + + auto* multihead = graph->CreateOpNode(&multihead_op_desc); + IR_NODE_LINK_TO(input0, multihead); + IR_NODE_LINK_TO(input1, multihead); + IR_NODE_LINK_TO(mul0_w, multihead); + IR_NODE_LINK_TO(mul1_w, multihead); + IR_NODE_LINK_TO(multihead, reshape2_qkv_out); + }; + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input1, input1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, multihead_pattern); + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); + + fuse_creater(input0, + input1, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul0_w, + mul1_w, + mul2_w, + reshape2_0, + reshape2_qkv_out, + scale, + scale_out); + + std::unordered_set marked_nodes({reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul2_w, + reshape2_qkv, + scale}); + // Remove unneeded nodes. + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void TrtCrossMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); +#ifdef PADDLE_WITH_TENSORRT + auto trt_version = paddle::inference::tensorrt::GetTrtRuntimeVersion(); + if (std::get<0>(trt_version) * 1000 + std::get<1>(trt_version) * 100 + + std::get<2>(trt_version) * 10 < + 8520) { + VLOG(3) << "Flash attention oss plugin only available for trt version >= " + "8.5.2.2. Stop this pass"; + return; + } +#else + // if no tensorrt, early stop + return; +#endif + bool with_dynamic_shape = Get("with_dynamic_shape"); + if (!with_dynamic_shape) { + VLOG(3) << "Cross attention oss plugin need trt " + "with_dynamic_shape. Stop this pass"; + return; + } + auto* scope = param_scope(); + int fusion_count = BuildCrossFusion(graph, name_scope_, scope); + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(trt_cross_multihead_matmul_fuse_pass, + paddle::framework::ir::TrtCrossMultiHeadMatmulFusePass); +REGISTER_PASS_CAPABILITY(trt_cross_multihead_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.h b/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.h new file mode 100644 index 00000000000..b605175513b --- /dev/null +++ b/paddle/fluid/framework/ir/trt_cross_multihead_matmul_fuse_pass.h @@ -0,0 +1,92 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct TrtCrossMultiHeadMatmulPattern : public PatternBase { + TrtCrossMultiHeadMatmulPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "cross_multihead_matmul") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(input1); + PATTERN_DECL_NODE(mul0); + PATTERN_DECL_NODE(mul1); + PATTERN_DECL_NODE(mul2); + PATTERN_DECL_NODE(mul0_w); + PATTERN_DECL_NODE(mul1_w); + PATTERN_DECL_NODE(mul2_w); + PATTERN_DECL_NODE(mul0_out); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(mul2_out); + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(transpose2_qkv_out); + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); +}; + +} // namespace patterns + +class TrtCrossMultiHeadMatmulFusePass : public FusePassBase { + public: + TrtCrossMultiHeadMatmulFusePass(); + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"trt_cross_multihead_matmul_fuse"}; + + private: + int BuildCrossFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc new file mode 100644 index 00000000000..eb5390d2a99 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc @@ -0,0 +1,579 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
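+//
+// At a glance (informal sketch of what this pass fuses): the matched
+// subgraph below is standard scaled dot-product self attention, per head
+//   Q = input * W_q,  K = input * W_k,  V = input * W_v
+//   Out = softmax(Q * K^T / sqrt(head_size)) * V
+// and it is rewritten into a single flash_multihead_matmul op, which the
+// TensorRT converter added in this patch lowers to the fMHA_V2 OSS plugin
+// (see flash_multihead_matmul_op.cc).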
+ +#include "paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.h" + +#include +#include "math.h" // NOLINT + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" +#ifdef PADDLE_WITH_TENSORRT +#include "paddle/fluid/inference/tensorrt/helper.h" +#endif +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +// input +// |q k v +// |------|-------| +// matmul matmul matmul +// | | | +// reshape reshape reshape +// | | | +// trans trans trans +// |(x) |(y) | +// matmul | +// | | +// scale |(y) +// | | +// softmax | +// |------matmul +// (x) | +// trans +// | +// reshape +// | +// output +// +// -> fused to +// +// input +// | +// flash_multihead_matmul +// | +// output + +PDNode* TrtFlashMultiHeadMatmulPattern::operator()() { + std::unordered_set mul_ops{"mul", "matmul_v2"}; + std::unordered_set matmul_ops{"matmul", "matmul_v2"}; + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_ops_input(mul_ops); + VLOG(5) << "Start match TrtFlashMultiHeadMatmulPattern"; + + // First path + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops); + auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul0_out_var = + pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops); + + mul0_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + + auto* reshape2_0_out_var = + pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2"); + reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_0_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops, "X"); + + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale"); + + auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + auto* scale_out_var = + pattern->NewNode(scale_out_repr())->assert_is_op_output("scale"); + scale_out_var->AsIntermediate()->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = + pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); + softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops); + + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = 
pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2"); + + // Second path to matmul + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops); + auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul1_out_var = + pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops); + + mul1_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + + auto* reshape2_1_out_var = + pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2"); + reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_1_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops, "Y"); // link to matmul qk + + // Third path to matmul + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops); + auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul2_out_var = + pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops); + + mul2_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + + auto* reshape2_2_out_var = + pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2"); + reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_2_out_var->AsIntermediate()->assert_is_ops_input( + matmul_ops); // link to matmul qkv + + // Q path + mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); + reshape2_0->LinksFrom({mul0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + // K path + mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var}); + + reshape2_1->LinksFrom({mul1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + // compute q*k + matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) + .LinksTo({matmul_qk_out_var}); + scale->LinksFrom({matmul_qk_out_var}).LinksTo({scale_out_var}); + softmax_qk->LinksFrom({scale_out_var}).LinksTo({softmax_qk_out_var}); + // V path + mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var}); + + reshape2_2->LinksFrom({mul2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + // compute q*k*v + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + + return reshape2_qkv_out_var; +} + +} // namespace patterns + +TrtFlashMultiHeadMatmulFusePass::TrtFlashMultiHeadMatmulFusePass() { + AddOpCompat(OpCompat("mul")) + .AddInput("X") // the shape shoule be (B, S, N*H) + .IsTensor() + .End() + .AddInput("Y") // the shape shoule be (N*H, N*H) + .IsTensor() + .End() + .AddOutput("Out") // the shape 
shoule be (B, S, N*H) + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(2) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H) + .IsType>() + .End(); + + // -->: (B, S, H, N) -> (B, H, S, N) + // <--: (B, H, S, N) -> (B, S, H, N) + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") // {0, 2, 1, 3} + .IsType>() + .End(); + + // QK (B, H, S, N)*(B, H, S, N) -> (B, H, S, S) + // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N) + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsType() // QK(anyvalue, will copy to new op) QKV(1.0) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") // QK(true) QKV(false) + .IsType() + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") // QK(true) QKV(false) + .IsType() + .End(); + + AddOpCompat(OpCompat("softmax")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 3}) // shape is (B, H, S, S), so axis is -1 or 3 + .End(); +} + +int TrtFlashMultiHeadMatmulFusePass::BuildFlashFusion( + Graph* graph, const std::string& name_scope, Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. 
+ patterns::TrtFlashMultiHeadMatmulPattern multihead_pattern(pattern, + name_scope); + + multihead_pattern(); + auto fuse_creater = [&](Node* input0, + Node* mul0, + Node* mul1, + Node* mul2, + Node* mul0_out, + Node* mul1_out, + Node* mul2_out, + Node* mul0_w, + Node* mul1_w, + Node* mul2_w, + Node* reshape2, + Node* reshape2_qkv_out, + Node* scale, + Node* scale_out) { + // get Device context + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + + auto scale_attr = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); + + // create multihead + OpDesc multihead_op_desc(mul0->Op()->Block()); + auto reshape_desc = reshape2->Op(); + int head_number = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(2); + + multihead_op_desc.SetType("flash_multihead_matmul"); + multihead_op_desc.SetInput("Input", {input0->Name()}); + + auto* wq_tensor = + scope->FindVar(mul0_w->Name())->GetMutable(); + auto* wk_tensor = + scope->FindVar(mul1_w->Name())->GetMutable(); + auto* wv_tensor = + scope->FindVar(mul2_w->Name())->GetMutable(); + // check the scale + int hidden_out = wq_tensor->dims()[1]; + int head_size = hidden_out / head_number; + if (abs(scale_attr - 1.0f / sqrt(static_cast(head_size))) > 1e-5) { + VLOG(3) << "scale of muilthead matmul do not fit the requirement of " + "flash attention plugin, Stop fusing."; + return; + } + + float* wq_data = wq_tensor->data(); + float* wk_data = wk_tensor->data(); + float* wv_data = wv_tensor->data(); + // combined_w_dims = [in,3,out] + auto combined_w_dims = + phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); + auto* combined_w_desc = mul0_w->Var(); + combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); + combined_w_desc->SetPersistable(true); + phi::DenseTensor tmp_combined_w_tensor; + tmp_combined_w_tensor.Resize(combined_w_dims); + float* tmp_combined_w_data = + dev_ctx->template HostAlloc(&tmp_combined_w_tensor); + + std::vector w_vec = {wq_data, wk_data, wv_data}; + int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2]; + // dims_h=in_feature, dims_w=out_feature + // Combine the three fc weights together. 
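+      // The combined tensor has shape [hidden_in, 3, hidden_out]: slice
+      // j = 0 holds W_q, j = 1 holds W_k and j = 2 holds W_v, so element
+      // wq[i][k] lands at offset i * (3 * dims_w) + k, wk[i][k] at
+      // i * (3 * dims_w) + dims_w + k, and wv[i][k] at
+      // i * (3 * dims_w) + 2 * dims_w + k in the loop below.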
+ // weight [Hidden_in * 3 * N * H] + for (int i = 0; i < dims_h; i++) { + for (int j = 0; j < 3; j++) { + for (int k = 0; k < dims_w; k++) { + int out_index = i * (3 * dims_w) + j * dims_w + k; + int in_index = i * dims_w + k; + tmp_combined_w_data[out_index] = w_vec[j][in_index]; + } + } + } + // clear weight for reuse + wq_tensor->clear(); + wq_tensor->Resize(combined_w_dims); + + float* new_combined_w_data = dev_ctx->template HostAlloc( + wq_tensor, sizeof(float) * wq_tensor->numel()); + memcpy(new_combined_w_data, + tmp_combined_w_data, + sizeof(float) * wq_tensor->numel()); + + scope->EraseVars({mul1_w->Name(), mul2_w->Name()}); + + multihead_op_desc.SetInput("W", {mul0_w->Name()}); + multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); + multihead_op_desc.SetAttr("alpha", scale_attr); + multihead_op_desc.SetAttr("head_number", head_number); + + auto* multihead = graph->CreateOpNode(&multihead_op_desc); + IR_NODE_LINK_TO(input0, multihead); + IR_NODE_LINK_TO(mul0_w, multihead); + IR_NODE_LINK_TO(multihead, reshape2_qkv_out); + }; + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, multihead_pattern); + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, 
multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); + + fuse_creater(input0, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul0_w, + mul1_w, + mul2_w, + reshape2_0, + reshape2_qkv_out, + scale, + scale_out); + + std::unordered_set marked_nodes({reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul1_w, + mul2_w, + reshape2_qkv, + scale}); + // Remove unneeded nodes. + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void TrtFlashMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + +#ifdef PADDLE_WITH_TENSORRT + auto trt_version = paddle::inference::tensorrt::GetTrtRuntimeVersion(); + if (std::get<0>(trt_version) * 1000 + std::get<1>(trt_version) * 100 + + std::get<2>(trt_version) * 10 < + 8520) { + VLOG(3) << "Flash attention oss plugin only available for trt version >= " + "8.5.2.2. Stop this pass"; + return; + } +#else + // if no tensorrt, early stop + return; +#endif + bool with_dynamic_shape = Get("with_dynamic_shape"); + if (!with_dynamic_shape) { + VLOG(3) << "Flash attention oss plugin need trt " + "with_dynamic_shape. Stop this pass"; + return; + } + + int fusion_count = BuildFlashFusion(graph, name_scope_, scope); + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(trt_flash_multihead_matmul_fuse_pass, + paddle::framework::ir::TrtFlashMultiHeadMatmulFusePass); +REGISTER_PASS_CAPABILITY(trt_flash_multihead_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .LE("elementwise_add", 1) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.h b/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.h new file mode 100644 index 00000000000..7606d99a446 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.h @@ -0,0 +1,91 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct TrtFlashMultiHeadMatmulPattern : public PatternBase { + TrtFlashMultiHeadMatmulPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "flash_multihead_matmul") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(mul0); + PATTERN_DECL_NODE(mul1); + PATTERN_DECL_NODE(mul2); + PATTERN_DECL_NODE(mul0_w); + PATTERN_DECL_NODE(mul1_w); + PATTERN_DECL_NODE(mul2_w); + PATTERN_DECL_NODE(mul0_out); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(mul2_out); + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(transpose2_qkv_out); + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); +}; + +} // namespace patterns + +class TrtFlashMultiHeadMatmulFusePass : public FusePassBase { + public: + TrtFlashMultiHeadMatmulFusePass(); + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"trt_flash_multihead_matmul_fuse"}; + + private: + int BuildFlashFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc index 145aa4ed00f..cf42775c2bd 100644 --- a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc @@ -631,6 +631,7 @@ PDNode* TrtMultiHeadMatmulV3Pattern::operator()() { return transpose2_2_out_var; } + } // namespace patterns void TrtMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { @@ -1667,6 +1668,7 @@ REGISTER_PASS(trt_multihead_matmul_fuse_pass_v2, paddle::framework::ir::TrtMultiHeadMatmulV2FusePass); REGISTER_PASS(trt_multihead_matmul_fuse_pass_v3, paddle::framework::ir::TrtMultiHeadMatmulV3FusePass); + REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v2) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() @@ -1677,6 +1679,7 @@ REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v2) .EQ("scale", 0) .LE("matmul", 1) .EQ("softmax", 0)); + REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v3) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a5526c7443e..3a54c7b4ed2 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2427,6 +2427,10 @@ USE_TRT_CONVERTER(expand_v2) USE_TRT_CONVERTER(take_along_axis) 
USE_TRT_CONVERTER(skip_groupnorm_act) USE_TRT_CONVERTER(preln_groupnorm_act) +#if IS_TRT_VERSION_GE(8522) +USE_TRT_CONVERTER(flash_multihead_matmul) +USE_TRT_CONVERTER(cross_multihead_matmul) +#endif #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) USE_TRT_CONVERTER(sparse_fc) USE_TRT_CONVERTER(sparse_multihead_matmul) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index c6c3e3b05fb..dbb49c1ed39 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -19,6 +19,7 @@ #ifdef PADDLE_WITH_HIP #include #endif + #include #include @@ -103,6 +104,8 @@ const std::vector kTRTSubgraphPasses({ "trt_multihead_matmul_fuse_pass_v3", // "multihead_matmul_roformer_fuse_pass", // "constant_folding_pass", // + "trt_flash_multihead_matmul_fuse_pass", // + "trt_cross_multihead_matmul_fuse_pass", // "vit_attention_fuse_pass", // #if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading. #else diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 76c74d55d11..7caa8765b8e 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -25,6 +25,8 @@ list( layer_norm_op.cc multihead_matmul_op.cc multihead_matmul_roformer_op.cc + flash_multihead_matmul_op.cc + cross_multihead_matmul_op.cc shuffle_channel_op.cc fill_any_like_op.cc where_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/cross_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/cross_multihead_matmul_op.cc new file mode 100644 index 00000000000..eda7fe19d21 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/cross_multihead_matmul_op.cc @@ -0,0 +1,277 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class CrossMultiheadMatMulOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(3) << "convert a cross_multihead_mamul op to a corresponding tensorrt " + "network structure"; + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; + } + PADDLE_ENFORCE_EQ( + with_fp16, + true, + platform::errors::Unimplemented( + "Trt cross attention oss plugin only support fp16 mode yet.")); + framework::OpDesc op_desc(op, nullptr); + auto* input_q = engine_->GetITensor(op_desc.Input("Input_q").front()); + auto* input_kv = engine_->GetITensor(op_desc.Input("Input_kv").front()); + // auto input_dims = input->getDimensions(); + auto output_name = op_desc.Output("Out")[0]; + + auto weight_q_name = op_desc.Input("W_q").front(); + auto* weight_q_v = scope.FindVar(weight_q_name); + auto* weight_q_t = weight_q_v->GetMutable(); + float* weight_q_data = nullptr; + weight_q_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(weight_q_name, *weight_q_t).get().values)); + const auto& weight_q_dims = weight_q_t->dims(); + int hidden_in_q = weight_q_dims[0]; + int hidden_out_q = weight_q_dims[1]; + int head_number_q = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number")); + int head_size_q = hidden_out_q / head_number_q; + int n_q = hidden_out_q; + auto transpose_weight_q = [](const float* src, + float* dst, + int head_number, + int head_size, + int hidden_in) { + for (int hn = 0; hn < head_number; hn++) { + for (int hs = 0; hs < head_size; hs++) { + for (int hi = 0; hi < hidden_in; hi++) { + int out_index = hn * head_size * hidden_in + hs * hidden_in + hi; + int in_index = hi * head_number * head_size + hn * head_size + hs; + dst[out_index] = src[in_index]; + } + } + } + }; + std::vector weight_q_data_tmp; + weight_q_data_tmp.reserve(weight_q_t->numel()); + memcpy(weight_q_data_tmp.data(), + weight_q_data, + weight_q_t->numel() * sizeof(float)); + transpose_weight_q(weight_q_data_tmp.data(), + weight_q_data, + head_number_q, + head_size_q, + hidden_in_q); + + nvinfer1::Weights weight_q{nvinfer1::DataType::kFLOAT, + static_cast(weight_q_data), + static_cast(weight_q_t->numel())}; + nvinfer1::Weights bias_q{}; + // add shuffle for FullyConnected layer + std::vector reshape_before_fc_q_shape_tensor; + nvinfer1::ITensor* input_q_shape_tensor = Shape(input_q); + for (int i = 0; i < 5; i++) { + reshape_before_fc_q_shape_tensor.push_back(Add1DConstantLayer(1)); + } + for (int i = 0; i < 3; i++) { + reshape_before_fc_q_shape_tensor[i] = + GetEleTensorOfShape(input_q_shape_tensor, i); + } + auto* reshape_before_fc_q_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input_q); + reshape_before_fc_q_layer->setInput( + 1, *Concat(reshape_before_fc_q_shape_tensor)); + reshape_before_fc_q_layer->setName( + ("shuffle_before_fc_q_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + nvinfer1::ILayer* fc_q_layer = nullptr; + fc_q_layer = TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *reshape_before_fc_q_layer->getOutput(0), + n_q, + weight_q, + bias_q); + fc_q_layer->setName( + ("multihead_mamul_fc_q(Output: " + output_name + ")").c_str()); + + // add shuffle for fc layer + auto* reshape_after_fc_q_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, 
*fc_q_layer->getOutput(0)); + std::vector mha_input_q_tensor_shape; + for (int i = 0; i < 4; i++) { + mha_input_q_tensor_shape.push_back(Add1DConstantLayer(1)); + } + mha_input_q_tensor_shape[0] = GetEleTensorOfShape(input_q_shape_tensor, 0); + mha_input_q_tensor_shape[1] = GetEleTensorOfShape(input_q_shape_tensor, 1); + mha_input_q_tensor_shape[2] = Add1DConstantLayer(head_number_q); + mha_input_q_tensor_shape[3] = Add1DConstantLayer(head_size_q); + reshape_after_fc_q_layer->setInput(1, *Concat(mha_input_q_tensor_shape)); + reshape_after_fc_q_layer->setName( + ("shuffle_after_fc_q_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + auto weight_kv_name = op_desc.Input("W_kv").front(); + auto* weight_kv_v = scope.FindVar(weight_kv_name); + auto* weight_kv_t = weight_kv_v->GetMutable(); + float* weight_kv_data = nullptr; + weight_kv_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(weight_kv_name, *weight_kv_t).get().values)); + + // (hidden_in, 2, hidden_out) + const auto& weight_kv_dims = weight_kv_t->dims(); + + int hidden_in = weight_kv_dims[0]; // channels_in + int two = weight_kv_dims[1]; // three + int hidden_out = weight_kv_dims[2]; // channels_out + int head_number = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number")); + int head_size = hidden_out / head_number; + + int n = two * hidden_out; + nvinfer1::ILayer* layer = nullptr; + + // [hidden_in, 3, head_number, head_size] + // -> [head_number, 3, head_size, hidden_in] + auto transpose_weight = [](const float* src, + float* dst, + int two, + int head_number, + int head_size, + int hidden_in) { + for (int hn = 0; hn < head_number; hn++) { + for (int t = 0; t < two; t++) { + for (int hs = 0; hs < head_size; hs++) { + for (int hi = 0; hi < hidden_in; hi++) { + int out_index = hn * two * head_size * hidden_in + + t * head_size * hidden_in + hs * hidden_in + hi; + int in_index = hi * two * head_number * head_size + + t * head_number * head_size + hn * head_size + hs; + dst[out_index] = src[in_index]; + } + } + } + } + }; + std::vector weight_kv_data_tmp; + + weight_kv_data_tmp.reserve(weight_kv_t->numel()); + memcpy(weight_kv_data_tmp.data(), + weight_kv_data, + weight_kv_t->numel() * sizeof(float)); + transpose_weight(weight_kv_data_tmp.data(), + weight_kv_data, + two, + head_number, + head_size, + hidden_in); + nvinfer1::Weights weight_kv{nvinfer1::DataType::kFLOAT, + static_cast(weight_kv_data), + static_cast(weight_kv_t->numel())}; + nvinfer1::Weights bias_kv{}; + + // add shuffle for FullyConnected layer + std::vector reshape_before_fc_shape_tensor; + nvinfer1::ITensor* input_shape_tensor = Shape(input_kv); + for (int i = 0; i < 5; i++) { + reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1)); + } + for (int i = 0; i < 3; i++) { + reshape_before_fc_shape_tensor[i] = + GetEleTensorOfShape(input_shape_tensor, i); + } + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input_kv); + reshape_before_fc_layer->setInput(1, + *Concat(reshape_before_fc_shape_tensor)); + reshape_before_fc_layer->setName( + ("shuffle_before_fc_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + nvinfer1::ILayer* fc_layer = nullptr; + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *reshape_before_fc_layer->getOutput(0), + n, + weight_kv, + bias_kv); + fc_layer->setName( + ("multihead_mamul_fc(Output: " + output_name + ")").c_str()); + + // add shuffle for fc layer + auto* reshape_after_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0)); + std::vector 
mha_input_tensor_shape; + for (int i = 0; i < 5; i++) { + mha_input_tensor_shape.push_back(Add1DConstantLayer(1)); + } + mha_input_tensor_shape[0] = GetEleTensorOfShape(input_shape_tensor, 0); + mha_input_tensor_shape[1] = GetEleTensorOfShape(input_shape_tensor, 1); + mha_input_tensor_shape[2] = Add1DConstantLayer(head_number); + mha_input_tensor_shape[3] = Add1DConstantLayer(2); + mha_input_tensor_shape[4] = Add1DConstantLayer(head_size); + reshape_after_fc_layer->setInput(1, *Concat(mha_input_tensor_shape)); + reshape_after_fc_layer->setName( + ("shuffle_after_fc_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + auto creator = GetPluginRegistry()->getPluginCreator("fMHCA", "1"); + assert(creator != nullptr); + std::vector fields{}; + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast( + malloc(sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + auto plugin = creator->createPlugin("fMHA_V2", plugin_collection); + free(plugin_collection); + std::vector plugin_inputs; + plugin_inputs.emplace_back(reshape_after_fc_q_layer->getOutput(0)); + plugin_inputs.emplace_back(reshape_after_fc_layer->getOutput(0)); + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + + // add shuffle + nvinfer1::ITensor* batch_tensor = + GetEleTensorOfShape(input_q_shape_tensor, 0); + nvinfer1::ITensor* length_tensor = + GetEleTensorOfShape(input_q_shape_tensor, 1); + auto* reshape_after_mha_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0)); + std::vector reshape_tensor; + reshape_tensor.push_back(batch_tensor); + reshape_tensor.push_back(length_tensor); + reshape_tensor.push_back(Add1DConstantLayer(-1)); + reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor)); + reshape_after_mha_layer->setName( + ("shuffle_last_multihead_matmul(Output: " + output_name + ")").c_str()); + // return + layer = reshape_after_mha_layer; + RreplenishLayerAndOutput( + layer, "cross_multihead_matmul", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(cross_multihead_matmul, + CrossMultiheadMatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc new file mode 100644 index 00000000000..2544a1fdff2 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc @@ -0,0 +1,190 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class FlashMultiheadMatMulOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(3) << "convert a flash_multihead_mamul op to a corresponding tensorrt " + "network structure"; + + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; + } + PADDLE_ENFORCE_EQ( + with_fp16, + true, + platform::errors::Unimplemented( + "Trt flash attention oss plugin only support fp16 mode yet.")); + + framework::OpDesc op_desc(op, nullptr); + auto* input = engine_->GetITensor(op_desc.Input("Input").front()); + + auto weight_name = op_desc.Input("W").front(); + auto* weight_v = scope.FindVar(weight_name); + auto* weight_t = weight_v->GetMutable(); + float* weight_data = nullptr; + weight_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(weight_name, *weight_t).get().values)); + + // (hidden_in, 3, hidden_out) + const auto& weight_dims = weight_t->dims(); + + int hidden_in = weight_dims[0]; // channels_in + int three = weight_dims[1]; // three + int hidden_out = weight_dims[2]; // channels_out + int head_number = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number")); + int head_size = hidden_out / head_number; + + int n = three * hidden_out; + nvinfer1::ILayer* layer = nullptr; + auto output_name = op_desc.Output("Out")[0]; + + // [hidden_in, 3, head_number, head_size] + // -> [head_number, 3, head_size, hidden_in] + auto transpose_weight = [](const float* src, + float* dst, + int three, + int head_number, + int head_size, + int hidden_in) { + for (int hn = 0; hn < head_number; hn++) { + for (int t = 0; t < three; t++) { + for (int hs = 0; hs < head_size; hs++) { + for (int hi = 0; hi < hidden_in; hi++) { + int out_index = hn * three * head_size * hidden_in + + t * head_size * hidden_in + hs * hidden_in + hi; + int in_index = hi * three * head_number * head_size + + t * head_number * head_size + hn * head_size + hs; + dst[out_index] = src[in_index]; + } + } + } + } + }; + std::vector weight_data_tmp; + weight_data_tmp.reserve(weight_t->numel()); + memcpy( + weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float)); + + transpose_weight(weight_data_tmp.data(), + weight_data, + three, + head_number, + head_size, + hidden_in); + nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + nvinfer1::Weights bias{}; + // add shuffle for FullyConnected layer + std::vector reshape_before_fc_shape_tensor; + nvinfer1::ITensor* input_shape_tensor = Shape(input); + + for (int i = 0; i < 5; i++) { + reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1)); + } + for (int i = 0; i < 3; i++) { + reshape_before_fc_shape_tensor[i] = + GetEleTensorOfShape(input_shape_tensor, i); + } + + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + reshape_before_fc_layer->setInput(1, + *Concat(reshape_before_fc_shape_tensor)); + reshape_before_fc_layer->setName( + ("shuffle_before_fc_multihead_matmul(Output: " + output_name + ")") + .c_str()); + nvinfer1::ILayer* fc_layer = nullptr; + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *reshape_before_fc_layer->getOutput(0), + n, + weight, + bias); + fc_layer->setName( + ("multihead_mamul_fc(Output: " + 
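// The FullyConnected output named here packs the Q, K and V projections for every head
// in a single tensor; the shuffle added next views it as
// [batch, seq_len, head_number, 3, head_size], the packed-QKV layout expected by the
// fMHA_V2 (flash attention) plugin, and the last shuffle flattens the plugin output
// back to [batch, seq_len, head_number * head_size].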
output_name + ")").c_str()); + + // add shuffle for fc layer + + auto* reshape_after_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0)); + std::vector mha_input_tensor_shape; + for (int i = 0; i < 5; i++) { + mha_input_tensor_shape.push_back(Add1DConstantLayer(1)); + } + mha_input_tensor_shape[0] = GetEleTensorOfShape(input_shape_tensor, 0); + mha_input_tensor_shape[1] = GetEleTensorOfShape(input_shape_tensor, 1); + mha_input_tensor_shape[2] = Add1DConstantLayer(head_number); + mha_input_tensor_shape[3] = Add1DConstantLayer(3); + mha_input_tensor_shape[4] = Add1DConstantLayer(head_size); + reshape_after_fc_layer->setInput(1, *Concat(mha_input_tensor_shape)); + reshape_after_fc_layer->setName( + ("shuffle_after_fc_multihead_matmul(Output: " + output_name + ")") + .c_str()); + auto creator = GetPluginRegistry()->getPluginCreator("fMHA_V2", "1"); + assert(creator != nullptr); + std::vector fields{}; + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast( + malloc(sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + auto plugin = creator->createPlugin("fMHA_V2", plugin_collection); + free(plugin_collection); + std::vector plugin_inputs; + plugin_inputs.emplace_back(reshape_after_fc_layer->getOutput(0)); + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + + // add shuffle + + nvinfer1::ITensor* batch_tensor = + GetEleTensorOfShape(input_shape_tensor, 0); + nvinfer1::ITensor* length_tensor = + GetEleTensorOfShape(input_shape_tensor, 1); + auto* reshape_after_mha_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0)); + std::vector reshape_tensor; + reshape_tensor.push_back(batch_tensor); + reshape_tensor.push_back(length_tensor); + reshape_tensor.push_back(Add1DConstantLayer(-1)); + reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor)); + reshape_after_mha_layer->setName( + ("shuffle_last_multihead_matmul(Output: " + output_name + ")").c_str()); + // return + layer = reshape_after_mha_layer; + RreplenishLayerAndOutput( + layer, "flash_multihead_matmul", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(flash_multihead_matmul, + FlashMultiheadMatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 4dd85feb0e6..fbbd77a4c98 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -66,6 +66,12 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("sparse_multihead_matmul"); int8_teller_set.insert("sparse_multihead_matmul"); #endif +#if IS_TRT_VERSION_GE(8522) + teller_set.insert("flash_multihead_matmul"); + int8_teller_set.insert("flash_multihead_matmul"); + teller_set.insert("cross_multihead_matmul"); + int8_teller_set.insert("cross_multihead_matmul"); +#endif #if IS_TRT_VERSION_GE(8200) teller_set.insert("round"); int8_teller_set.insert("round"); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cross_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cross_multihead_matmul.py new file mode 100644 index 00000000000..c427a94772b --- /dev/null +++ 
b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cross_multihead_matmul.py @@ -0,0 +1,326 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest + +import paddle.inference as paddle_infer + + +class TrtConvertCrossMultiHeadMatmulTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8520: + return False + return True + + def sample_program_configs(self): + def generate_input1(batch, dim1): + return np.random.random((batch, dim1, 320)).astype(np.float32) / 10 + + def generate_input2(batch, dim2): + return np.random.random((batch, dim2, 768)).astype(np.float32) / 10 + + def generate_weight1(): + return np.random.random((320, 320)).astype(np.float32) / 10 + + def generate_weight2(): + return np.random.random((768, 320)).astype(np.float32) / 10 + + for batch in [1, 2]: + self.batch = batch + for reshape_shape in [[0, 0, 8, 40]]: + for dim1 in [4096]: + for dim2 in [768]: + dics = [ + {"trans_x": False, "trans_y": False}, + {"shape": reshape_shape}, + {"axis": [0, 2, 1, 3]}, + {"trans_x": False, "trans_y": False}, + {"shape": reshape_shape}, + {"axis": [0, 2, 1, 3]}, + {"trans_x": False, "trans_y": False}, + {"shape": reshape_shape}, + {"axis": [0, 2, 1, 3]}, + { + "trans_x": False, + "trans_y": True, + }, + { + "scale": 0.15811388194561005, + "bias": 0.0, + "bias_after_scale": True, + }, + {"axis": -1, "is_test": True}, + {"trans_x": False, "trans_y": False}, + {"axis": [0, 2, 1, 3]}, + {"shape": [0, 0, 320]}, + ] + + ops_config = [ + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul1_weight"], + }, + "op_outputs": {"Out": ["mul1_output"]}, + "op_attrs": dics[0], + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["mul1_output"], + }, + "op_outputs": { + "Out": ["reshape21_output"], + "XShape": ["reshape21_output_xshape"], + }, + "op_attrs": dics[1], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape21_output"]}, + "op_outputs": { + "Out": ["transpose21_output"], + "XShape": ["transpose21_output_xshape"], + }, + "op_attrs": dics[2], + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data2"], + "Y": ["mul2_weight"], + }, + "op_outputs": {"Out": ["mul2_output"]}, + "op_attrs": dics[3], + }, + { + "op_type": "reshape2", + "op_inputs": {"X": ["mul2_output"]}, + "op_outputs": { + "Out": ["reshape22_output"], + "XShape": ["reshape22_output_xshape"], + }, + "op_attrs": dics[4], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape22_output"]}, + "op_outputs": { + "Out": ["transpose22_output"], + "XShape": ["transpose22_output_xshape"], + }, + "op_attrs": dics[5], + }, + { + "op_type": 
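# Note: in this cross-attention pattern the Q projection reads input_data1, while the
# K projection above and the V projection defined here both read input_data2; that is
# what distinguishes this graph from the self-attention graph in the flash test below.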
"matmul_v2", + "op_inputs": { + "X": ["input_data2"], + "Y": ["mul3_weight"], + }, + "op_outputs": {"Out": ["mul3_output"]}, + "op_attrs": dics[6], + }, + { + "op_type": "reshape2", + "op_inputs": {"X": ["mul3_output"]}, + "op_outputs": { + "Out": ["reshape23_output"], + "XShape": ["reshape23_output_xshape"], + }, + "op_attrs": dics[7], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape23_output"]}, + "op_outputs": { + "Out": ["transpose23_output"], + "XShape": ["transpose23_output_xshape"], + }, + "op_attrs": dics[8], + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["transpose21_output"], + "Y": ["transpose22_output"], + }, + "op_outputs": {"Out": ["matmul1_output"]}, + "op_attrs": dics[9], + }, + { + "op_type": "scale", + "op_inputs": { + "X": ["matmul1_output"], + }, + "op_outputs": {"Out": ["scale_output"]}, + "op_attrs": dics[10], + }, + { + "op_type": "softmax", + "op_inputs": {"X": ["scale_output"]}, + "op_outputs": {"Out": ["softmax_output"]}, + "op_attrs": dics[11], + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["softmax_output"], + "Y": ["transpose23_output"], + }, + "op_outputs": {"Out": ["matmul2_output"]}, + "op_attrs": dics[12], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["matmul2_output"]}, + "op_outputs": { + "Out": ["transpose24_output"], + "XShape": ["transpose24_output_xshape"], + }, + "op_attrs": dics[13], + }, + { + "op_type": "reshape2", + "op_inputs": {"X": ["transpose24_output"]}, + "op_outputs": { + "Out": ["reshape24_output"], + "XShape": ["reshape24_output_xshape"], + }, + "op_attrs": dics[14], + }, + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "mul1_weight": TensorConfig( + data_gen=partial(generate_weight1) + ), + "mul2_weight": TensorConfig( + data_gen=partial(generate_weight2) + ), + "mul3_weight": TensorConfig( + data_gen=partial(generate_weight2) + ), + }, + inputs={ + "input_data1": TensorConfig( + data_gen=partial( + generate_input1, batch, dim1 + ) + ), + "input_data2": TensorConfig( + data_gen=partial( + generate_input2, batch, dim2 + ) + ), + }, + outputs=["reshape24_output"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + # The last dim of input1 and input2 should be static. 
+ self.dynamic_shape.min_input_shape = { + "input_data1": [1, 4096, 320], + "input_data2": [1, 77, 768], + } + self.dynamic_shape.max_input_shape = { + "input_data1": [8, 4096, 320], + "input_data2": [8, 77, 768], + } + self.dynamic_shape.opt_input_shape = { + "input_data1": [2, 4096, 320], + "input_data2": [2, 77, 768], + } + + def clear_dynamic_shape(): + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + self.trt_param.workspace_size = 2013265920 + yield self.create_inference_config(), (1, 4), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), (1, 4), (1e-2, 1e-3) + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + self.trt_param.workspace_size = 2013265920 + yield self.create_inference_config(), (1, 3), (1e-5, 1e-4) + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), (1, 3), (1e-2, 1e-3) + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + if self.dynamic_shape.min_input_shape == {}: + return True + return False + + self.add_skip_case( + teller1, + SkipReasons.TRT_NOT_IMPLEMENTED, + "TThe cross attention trt oss plugin do not support static shape yet", + ) + + def teller2(program_config, predictor_config): + if self.trt_param.precision == paddle_infer.PrecisionType.Float32: + return True + return False + + self.add_skip_case( + teller2, + SkipReasons.TRT_NOT_IMPLEMENTED, + "The cross attention trt oss plugin do not support fp32 yet", + ) + + def teller3(program_config, predictor_config): + if self.trt_param.precision == paddle_infer.PrecisionType.Int8: + return True + return False + + self.add_skip_case( + teller3, + SkipReasons.TRT_NOT_IMPLEMENTED, + "The cross attention trt oss plugin do not support int8 yet.", + ) + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flash_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flash_multihead_matmul.py new file mode 100644 index 00000000000..ebede2dce89 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flash_multihead_matmul.py @@ -0,0 +1,321 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
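# A minimal standalone sketch (not part of the patch) of the TensorRT version gate that
# both of these tests apply in is_program_valid(); the helper name below is illustrative,
# and it assumes get_trt_compile_version() returns a (major, minor, patch) tuple.
def _trt_oss_fmha_available(ver):
    # 8520 encodes TensorRT 8.5.2 via major*1000 + minor*100 + patch*10, the same scheme
    # as the IS_TRT_VERSION_GE(8522) gate added to op_teller.cc in this patch.
    return ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 >= 8520

assert _trt_oss_fmha_available((8, 5, 2))
assert not _trt_oss_fmha_available((8, 4, 3))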
+ +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest + +import paddle.inference as paddle_infer + + +class TrtConvertFlashMultiHeadMatmulTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8520: + return False + return True + + def sample_program_configs(self): + def generate_input1(batch, dim1): + return np.random.rand(batch, dim1, 320).astype(np.float32) / 10 + + def generate_weight1(): + return np.random.rand(320, 320).astype(np.float32) / 10 + + for batch in [1, 2]: + self.batch = batch + for reshape_shape in [[0, 0, 8, 40]]: + for dim1 in [4096]: + dics = [ + {"trans_x": False, "trans_y": False}, # 0,matmul_v2_q + {"shape": reshape_shape}, # 1,reshape_q + { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout", + }, # 2,trans_q + {"trans_x": False, "trans_y": False}, # 3,matmul_v2_k + {"shape": reshape_shape}, # 4,reshape_k + { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout", + }, # 5,trans_k + {"trans_x": False, "trans_y": False}, # 6,matmul_v2_q + {"shape": reshape_shape}, # 7,reshape_q + { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout", + }, # 8,trans_q + { # 9,matmul_qk + "trans_x": False, + "trans_y": True, + }, + { # 10,scale + "scale": 0.15811388194561005, + "bias": 0.0, + "bias_after_scale": True, + }, + {"axis": -1, "is_test": True}, # 11,softmax + {"trans_x": False, "trans_y": False}, # 12,matmul_qkv + { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout", + }, # 13,trans_qkv + {"shape": [0, 0, 320]}, # 14,reshape_qkv + ] + + ops_config = [ + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul1_weight"], + }, + "op_outputs": {"Out": ["mul1_output"]}, + "op_attrs": dics[0], + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["mul1_output"], + }, + "op_outputs": { + "Out": ["reshape21_output"], + "XShape": ["reshape21_output_xshape"], + }, + "op_attrs": dics[1], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape21_output"]}, + "op_outputs": { + "Out": ["transpose21_output"], + "XShape": ["transpose21_output_xshape"], + }, + "op_attrs": dics[2], + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul2_weight"], + }, + "op_outputs": {"Out": ["mul2_output"]}, + "op_attrs": dics[3], + }, + { + "op_type": "reshape2", + "op_inputs": {"X": ["mul2_output"]}, + "op_outputs": { + "Out": ["reshape22_output"], + "XShape": ["reshape22_output_xshape"], + }, + "op_attrs": dics[4], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape22_output"]}, + "op_outputs": { + "Out": ["transpose22_output"], + "XShape": ["transpose22_output_xshape"], + }, + "op_attrs": dics[5], + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul3_weight"], + }, + "op_outputs": {"Out": ["mul3_output"]}, + "op_attrs": dics[6], + }, + { + "op_type": "reshape2", + "op_inputs": {"X": ["mul3_output"]}, + "op_outputs": { + "Out": ["reshape23_output"], + "XShape": ["reshape23_output_xshape"], + }, + "op_attrs": dics[7], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape23_output"]}, + "op_outputs": { + "Out": ["transpose23_output"], + "XShape": ["transpose23_output_xshape"], + }, + "op_attrs": dics[8], + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": 
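# This matmul computes Q @ K^T (trans_y=True in dics[9]); the scale applied afterwards,
# 0.15811388..., is approximately 1 / sqrt(40), i.e. 1 / sqrt(head_size) for the
# [0, 0, 8, 40] reshape used above.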
["transpose21_output"], + "Y": ["transpose22_output"], + }, + "op_outputs": {"Out": ["matmul1_output"]}, + "op_attrs": dics[9], + }, + { + "op_type": "scale", + "op_inputs": { + "X": ["matmul1_output"], + }, + "op_outputs": {"Out": ["scale_output"]}, + "op_attrs": dics[10], + }, + { + "op_type": "softmax", + "op_inputs": {"X": ["scale_output"]}, + "op_outputs": {"Out": ["softmax_output"]}, + "op_attrs": dics[11], + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["softmax_output"], + "Y": ["transpose23_output"], + }, + "op_outputs": {"Out": ["matmul2_output"]}, + "op_attrs": dics[12], + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["matmul2_output"]}, + "op_outputs": { + "Out": ["transpose24_output"], + "XShape": ["transpose24_output_xshape"], + }, + "op_attrs": dics[13], + }, + { + "op_type": "reshape2", + "op_inputs": {"X": ["transpose24_output"]}, + "op_outputs": { + "Out": ["reshape24_output"], + "XShape": ["reshape24_output_xshape"], + }, + "op_attrs": dics[14], + }, + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "mul1_weight": TensorConfig( + data_gen=partial(generate_weight1) + ), + "mul2_weight": TensorConfig( + data_gen=partial(generate_weight1) + ), + "mul3_weight": TensorConfig( + data_gen=partial(generate_weight1) + ), + }, + inputs={ + "input_data1": TensorConfig( + data_gen=partial(generate_input1, batch, dim1) + ) + }, + outputs=["reshape24_output"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + # The last dim of input1 and input2 should be static. + self.dynamic_shape.min_input_shape = { + "input_data1": [1, 4096, 320], + } + self.dynamic_shape.max_input_shape = { + "input_data1": [16, 4096, 320], + } + self.dynamic_shape.opt_input_shape = { + "input_data1": [2, 4096, 320], + } + + def clear_dynamic_shape(): + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + self.trt_param.workspace_size = 2013265920 + yield self.create_inference_config(), (1, 2), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), (1, 2), (1e-3, 1e-3) + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + self.trt_param.workspace_size = 2013265920 + yield self.create_inference_config(), (1, 2), (1e-5, 1e-4) + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), (1, 2), (1e-2, 1e-3) + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + if self.dynamic_shape.min_input_shape == {}: + return True + return False + + self.add_skip_case( + teller1, + SkipReasons.TRT_NOT_IMPLEMENTED, + "TThe flash attention trt oss plugin do not support static shape yet", + ) + + def teller2(program_config, predictor_config): + if self.trt_param.precision == paddle_infer.PrecisionType.Float32: + return True + return False + + self.add_skip_case( + teller2, + SkipReasons.TRT_NOT_IMPLEMENTED, + "The flash attention trt oss plugin do not support fp32 yet", + ) + + def teller3(program_config, predictor_config): + if self.trt_param.precision == 
paddle_infer.PrecisionType.Int8: + return True + return False + + self.add_skip_case( + teller3, + SkipReasons.TRT_NOT_IMPLEMENTED, + "The flash attention trt oss plugin does not support int8 yet.", + ) + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab
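For reference, the quadruple loop that both new converters use to re-lay the fused attention weight before building the FullyConnected layer is equivalent to the short numpy sketch below (not part of the patch; the function name and the numpy dependency are only illustrative). Here n_mat is 3 for the fused QKV weight in flash_multihead_matmul and 2 for the fused KV weight in cross_multihead_matmul.

import numpy as np

def relayout_fused_attention_weight(w, n_mat, head_number, head_size):
    # [hidden_in, n_mat, head_number * head_size] -> [head_number, n_mat, head_size, hidden_in],
    # i.e. dst[hn, t, hs, hi] = src[hi, t, hn, hs], matching the converters' nested loops.
    hidden_in = w.shape[0]
    w = w.reshape(hidden_in, n_mat, head_number, head_size)
    return np.ascontiguousarray(w.transpose(2, 1, 3, 0))

# Shapes used by the unit tests: hidden size 320 = 8 heads * head_size 40.
w = np.random.rand(320, 3, 320).astype(np.float32)
print(relayout_fused_attention_weight(w, 3, 8, 40).shape)  # (8, 3, 40, 320)

This head-major ordering is what lets the 5-D shuffle placed after the FullyConnected layer line up with the [batch, seq, head_number, n_mat, head_size] input layout consumed by the fMHA_V2 and fMHCA plugins.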