From 6934ac797f6ae6d3c83529af2f510ac194452d66 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Mon, 10 Apr 2023 15:23:52 +0800 Subject: [PATCH] [Paddle Inference] Support two inputs of multihead attention named qk_multihead. (#52455) * Support two inputs of multihead attention named qk_multihead --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/trt_qk_multihead_matmul_fuse_pass.cc | 591 ++++++++++++++++++ .../ir/trt_qk_multihead_matmul_fuse_pass.h | 104 +++ .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/api/paddle_pass_builder.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../convert/qk_multihead_matmul_op.cc | 301 +++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 2 + .../test_trt_convert_qk_multihead_matmul.py | 385 ++++++++++++ 9 files changed, 1387 insertions(+) create mode 100644 paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.h create mode 100644 paddle/fluid/inference/tensorrt/convert/qk_multihead_matmul_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_qk_multihead_matmul.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index fbec6488568..91c3ba6d608 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -134,6 +134,7 @@ if(WITH_TENSORRT) pass_library(trt_multihead_matmul_fuse_pass inference) pass_library(trt_flash_multihead_matmul_fuse_pass inference) pass_library(trt_cross_multihead_matmul_fuse_pass inference) + pass_library(trt_qk_multihead_matmul_fuse_pass inference) pass_library(trt_skip_layernorm_fuse_pass inference) pass_library(merge_layernorm_fuse_pass inference) pass_library(preln_skip_layernorm_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.cc new file mode 100644 index 00000000000..df1476e9db3 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.cc @@ -0,0 +1,591 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
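For orientation, the subgraph this new pass targets is ordinary scaled-dot-product attention in which Q and K are projected from one input tensor and V from a second one. A minimal numpy sketch of that computation, assuming hidden = head_number * head_size (all names here are illustrative, not part of the patch):

import numpy as np

def qk_attention_reference(x_qk, x_v, wq, bq, wk, bk, wv, bv, head_number):
    # x_qk, x_v: [batch, seq_len, hidden]; wq/wk/wv: [hidden, hidden]; bq/bk/bv: [hidden]
    batch, seq_len, hidden = x_qk.shape
    head_size = hidden // head_number

    def project(x, w, b):
        y = x @ w + b                                            # matmul_v2 + elementwise_add
        y = y.reshape(batch, seq_len, head_number, head_size)    # reshape2
        return y.transpose(0, 2, 1, 3)                           # transpose2 with axis [0, 2, 1, 3]

    q = project(x_qk, wq, bq)   # Q branch
    k = project(x_qk, wk, bk)   # K branch (shares its input with Q)
    v = project(x_v, wv, bv)    # V branch (separate input)

    logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)    # matmul_qk + scale
    logits -= logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)                   # softmax
    out = probs @ v                                              # matmul_qkv
    return out.transpose(0, 2, 1, 3).reshape(batch, seq_len, hidden)  # transpose2 + reshape2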
+ +#include "paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.h" + +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" +#ifdef PADDLE_WITH_TENSORRT +#include "paddle/fluid/inference/tensorrt/helper.h" +#endif +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +// input_qk input_v +// |q |k v +// |------| | +// matmul matmul matmul +// | | | +// reshape reshape reshape +// | | | +// trans trans trans +// |(x) |(x) | +// matmul | +// | | +// scale | +// | | +// softmax |(y) +// |------matmul +// | +// trans +// | +// reshape +// | +// output +// +// -> fused to +// +// input_qk intput_v +// | +// qk_multihead_matmul +// | +// output + +PDNode* TrtQKMultiHeadMatmulPattern::operator()() { + std::unordered_set mul_ops{"mul", "matmul_v2"}; + std::unordered_set matmul_ops{"matmul", "matmul_v2"}; + auto* input0 = pattern->NewNode(input0_repr()); + auto* input1 = pattern->NewNode(input1_repr()); + + input0->assert_is_ops_input(mul_ops); + input1->assert_is_ops_input(mul_ops); + VLOG(5) << "Start match TrtQKMultiHeadMatmulPattern"; + + // First path + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops); + auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul0_out_var = pattern->NewNode(mul0_out_repr()) + ->assert_is_ops_output(mul_ops) + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + + auto* elementwise0 = + pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add"); + auto* elementwise0_w = pattern->NewNode(elementwise0_w_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* elementwise0_out = pattern->NewNode(elementwise0_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("reshape2", "X") + ->AsIntermediate(); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + + auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr()) + ->assert_is_op_output("reshape2") + ->assert_is_op_input("transpose2") + ->AsIntermediate(); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2") + ->assert_is_ops_input(matmul_ops, "X") + ->AsIntermediate(); + + auto* matmul_qk = + pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops); + auto* matmul_qk_out_var = pattern->NewNode(matmul_qk_out_repr()) + ->assert_is_ops_output(matmul_ops) + ->assert_is_op_input("scale") + ->AsIntermediate(); + + auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + auto* scale_out_var = pattern->NewNode(scale_out_repr()) + ->assert_is_op_output("scale") + ->assert_is_op_input("softmax") + ->AsIntermediate(); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) + ->assert_is_op_output("softmax") + ->assert_is_ops_input(matmul_ops) + ->AsIntermediate(); + + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops); + auto* matmul_qkv_out_var = pattern->NewNode(matmul_qkv_out_repr()) + ->assert_is_ops_output(matmul_ops) + ->assert_is_op_input("transpose2") + ->AsIntermediate(); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* 
transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2") + ->assert_is_op_input("reshape2") + ->AsIntermediate(); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2") + ->AsOutput(); + + // Second path to matmul + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops); + auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul1_out_var = pattern->NewNode(mul1_out_repr()) + ->assert_is_ops_output(mul_ops) + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + + auto* elementwise1 = + pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add"); + auto* elementwise1_w = pattern->NewNode(elementwise1_w_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* elementwise1_out = pattern->NewNode(elementwise1_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("reshape2", "X") + ->AsIntermediate(); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + + auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr()) + ->assert_is_op_output("reshape2") + ->assert_is_op_input("transpose2") + ->AsIntermediate(); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2") + ->assert_is_ops_input(matmul_ops, "Y") + ->AsIntermediate(); // link to matmul qk + + // Third path to matmul + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops); + auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) + ->AsInput() + ->assert_is_ops_input(mul_ops, "Y"); + auto* mul2_out_var = pattern->NewNode(mul2_out_repr()) + ->assert_is_ops_output(mul_ops) + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + + auto* elementwise2 = + pattern->NewNode(elementwise2_repr())->assert_is_op("elementwise_add"); + auto* elementwise2_w = pattern->NewNode(elementwise2_w_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* elementwise2_out = pattern->NewNode(elementwise2_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("reshape2", "X") + ->AsIntermediate(); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + + auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr()) + ->assert_is_op_output("reshape2") + ->assert_is_op_input("transpose2") + ->AsIntermediate(); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2") + ->assert_is_ops_input(matmul_ops) + ->AsIntermediate(); // link to matmul qkv + + // Q path + mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); + elementwise0->LinksFrom({mul0_out_var, elementwise0_w}) + .LinksTo({elementwise0_out}); + + reshape2_0->LinksFrom({elementwise0_out}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + // K path + mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var}); + elementwise1->LinksFrom({mul1_out_var, elementwise1_w}) + .LinksTo({elementwise1_out}); + + 
reshape2_1->LinksFrom({elementwise1_out}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + // compute q*k + matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) + .LinksTo({matmul_qk_out_var}); + scale->LinksFrom({matmul_qk_out_var}).LinksTo({scale_out_var}); + softmax_qk->LinksFrom({scale_out_var}).LinksTo({softmax_qk_out_var}); + // V path + mul2->LinksFrom({input1, mul2_w_var}).LinksTo({mul2_out_var}); + elementwise2->LinksFrom({mul2_out_var, elementwise2_w}) + .LinksTo({elementwise2_out}); + + reshape2_2->LinksFrom({elementwise2_out}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + // compute q*k*v + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + + return reshape2_qkv_out_var; +} + +} // namespace patterns + +int TrtQkMultiHeadMatmulFusePass::BuildQkFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. + patterns::TrtQKMultiHeadMatmulPattern multihead_pattern(pattern, name_scope); + + multihead_pattern(); + auto fuse_creater = [&](Node* input0, + Node* input1, + Node* mul0, + Node* mul1, + Node* mul2, + Node* mul0_out, + Node* mul1_out, + Node* mul2_out, + Node* mul0_w, + Node* mul1_w, + Node* mul2_w, + Node* elementwise0, + Node* elementwise0_w, + Node* elementwise1, + Node* elementwise1_w, + Node* elementwise2, + Node* elementwise2_w, + Node* reshape2, + Node* reshape2_qkv_out, + Node* scale, + Node* scale_out) { + // get Device context + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + + auto scale_attr = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); + + // create multihead + OpDesc multihead_op_desc(mul0->Op()->Block()); + auto reshape_desc = reshape2->Op(); + int head_number = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(2); + multihead_op_desc.SetType("qk_multihead_matmul"); + multihead_op_desc.SetInput("Input_qk", {input0->Name()}); + multihead_op_desc.SetInput("Input_v", {input1->Name()}); + + auto* wq_tensor = + scope->FindVar(mul0_w->Name())->GetMutable(); + auto* wk_tensor = + scope->FindVar(mul1_w->Name())->GetMutable(); + auto* bq_tensor = + scope->FindVar(elementwise0_w->Name())->GetMutable(); + auto* bk_tensor = + scope->FindVar(elementwise1_w->Name())->GetMutable(); + + int hidden_out = wq_tensor->dims()[1]; + int head_size = hidden_out / head_number; + if (abs(scale_attr - 1.0f / sqrt(static_cast(head_size))) > 1e-5) { + VLOG(3) << "scale of muilthead matmul do not fit the requirement of " + "qk attention plugin, Stop fusing."; + return; + } + VLOG(3) << "trt qk attention get wq_tensor name = " << mul0_w->Name() + << "trt qk attention get wk_tensor name = " << mul1_w->Name(); + + auto* wq_data = wq_tensor->data(); + auto* wk_data = wk_tensor->data(); + auto* bq_data = bq_tensor->data(); + auto* bk_data = bk_tensor->data(); + + // combined_w_dims = [in,2,out] + auto combined_w_qk_dims = + phi::make_ddim({wq_tensor->dims()[0], 2, wq_tensor->dims()[1]}); + auto combined_bias_dims = phi::make_ddim({2, bq_tensor->dims()[0]}); + + VLOG(3) << "trt qk attention trt wq_dim in:" << wq_tensor->dims()[0] + << 
"trt qk attention trt wk_dim out:" << wq_tensor->dims()[1]; + auto* combined_w_qk_desc = mul0_w->Var(); + combined_w_qk_desc->SetShape( + {wq_tensor->dims()[0], 2, wq_tensor->dims()[1]}); + combined_w_qk_desc->SetPersistable(true); + phi::DenseTensor tmp_combined_w_qk_tensor; + tmp_combined_w_qk_tensor.Resize(combined_w_qk_dims); + float* tmp_combined_w_qk_data = + dev_ctx->template HostAlloc(&tmp_combined_w_qk_tensor); + + std::vector w_vec = {wq_data, wk_data}; + int dims_h = combined_w_qk_dims[0], dims_w = combined_w_qk_dims[2]; + // dims_h=in_feature, dims_w=out_feature + // Combine the two fc weights together. + // weight [Hidden_in * 2 * N * H] + for (int i = 0; i < dims_h; i++) { + for (int j = 0; j < 2; j++) { + for (int k = 0; k < dims_w; k++) { + int out_index = i * (2 * dims_w) + j * dims_w + k; + int in_index = i * dims_w + k; + tmp_combined_w_qk_data[out_index] = w_vec[j][in_index]; + } + } + } + wq_tensor->clear(); + wq_tensor->Resize(combined_w_qk_dims); + auto* new_combined_w_qk_data = dev_ctx->template HostAlloc( + wq_tensor, sizeof(float) * wq_tensor->numel()); + memcpy(new_combined_w_qk_data, + tmp_combined_w_qk_data, + sizeof(float) * wq_tensor->numel()); + + scope->EraseVars({mul1_w->Name()}); + auto* combined_bias_desc = elementwise0_w->Var(); + combined_bias_desc->SetShape({2, bq_tensor->dims()[0]}); + combined_bias_desc->SetPersistable(true); + + phi::DenseTensor tmp_combined_bias_tensor; + tmp_combined_bias_tensor.Resize(combined_bias_dims); + float* tmp_combined_bias_data = + dev_ctx->template HostAlloc(&tmp_combined_bias_tensor); + + size_t bias_size = bq_tensor->numel(); + memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size); + memcpy( + tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size); + + bq_tensor->clear(); + bq_tensor->Resize(combined_bias_dims); + auto* new_combined_bias_data = dev_ctx->template HostAlloc( + bq_tensor, sizeof(float) * bq_tensor->numel()); + + memcpy(new_combined_bias_data, + tmp_combined_bias_data, + sizeof(float) * bq_tensor->numel()); + + scope->EraseVars({elementwise1_w->Name()}); + + multihead_op_desc.SetInput("W_qk", {mul0_w->Name()}); + multihead_op_desc.SetInput("W_v", {mul2_w->Name()}); + multihead_op_desc.SetInput("B_qk", {elementwise0_w->Name()}); + multihead_op_desc.SetInput("B_v", {elementwise2_w->Name()}); + multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); + multihead_op_desc.SetAttr("alpha", scale_attr); + multihead_op_desc.SetAttr("head_number", head_number); + + auto* multihead = graph->CreateOpNode(&multihead_op_desc); + IR_NODE_LINK_TO(input0, multihead); + IR_NODE_LINK_TO(input1, multihead); + IR_NODE_LINK_TO(mul0_w, multihead); + IR_NODE_LINK_TO(mul2_w, multihead); + IR_NODE_LINK_TO(elementwise0_w, multihead); + IR_NODE_LINK_TO(elementwise2_w, multihead); + IR_NODE_LINK_TO(multihead, reshape2_qkv_out); + }; + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input1, input1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(elementwise0, elementwise0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + elementwise0_w, elementwise0_w, multihead_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH( + elementwise0_out, elementwise0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_0_out, reshape2_0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_0_out, transpose2_0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + elementwise1_w, elementwise1_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + elementwise1_out, elementwise1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_1_out, reshape2_1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_1_out, transpose2_1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(elementwise2, elementwise2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + elementwise2_w, elementwise2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + elementwise2_out, elementwise2_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_2_out, reshape2_2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_2_out, transpose2_2_out, multihead_pattern); + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + softmax_qk_out, softmax_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_qkv_out, matmul_qkv_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_qkv_out, reshape2_qkv_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv, transpose2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); + + fuse_creater(input0, + input1, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul0_w, + mul1_w, + mul2_w, + elementwise0, + elementwise0_w, + elementwise1, + elementwise1_w, + elementwise2, + elementwise2_w, + reshape2_0, + reshape2_qkv_out, + scale, + scale_out); + + std::unordered_set marked_nodes({reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + 
mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + elementwise0, + elementwise0_out, + elementwise1, + elementwise1_w, + elementwise1_out, + elementwise2, + elementwise2_out, + reshape2_qkv, + scale}); + // Remove unneeded nodes. + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +void TrtQkMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); +#ifdef PADDLE_WITH_TENSORRT + auto trt_version = paddle::inference::tensorrt::GetTrtRuntimeVersion(); + if (std::get<0>(trt_version) * 1000 + std::get<1>(trt_version) * 100 + + std::get<2>(trt_version) * 10 < + 8520) { + VLOG(3) << "Qk attention oss plugin only available for trt version >= " + "8.5.2.2. Stop this pass"; + return; + } +#else + // if no tensorrt, early stop + return; +#endif + bool with_dynamic_shape = Get("with_dynamic_shape"); + if (!with_dynamic_shape) { + VLOG(3) << "Qk attention oss plugin need trt " + "with_dynamic_shape. Stop this pass"; + return; + } + auto* scope = param_scope(); + int fusion_count = BuildQkFusion(graph, name_scope_, scope); + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(trt_qk_multihead_matmul_fuse_pass, + paddle::framework::ir::TrtQkMultiHeadMatmulFusePass); +REGISTER_PASS_CAPABILITY(trt_qk_multihead_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .EQ("softmax", 0) + .EQ("matmul_v2", 0)); diff --git a/paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.h b/paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.h new file mode 100644 index 00000000000..abc0d63e140 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.h @@ -0,0 +1,104 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
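// Declares TrtQKMultiHeadMatmulPattern and TrtQkMultiHeadMatmulFusePass, which
// match the attention subgraph whose Q and K branches share one input while V
// uses a second input, and replace it with a single qk_multihead_matmul op that
// the TensorRT converter lowers onto the fMHA_V2 plugin.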
+ +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct TrtQKMultiHeadMatmulPattern : public PatternBase { + TrtQKMultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "qk_multihead_matmul") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(input0); + PATTERN_DECL_NODE(input1); + PATTERN_DECL_NODE(mul0); + PATTERN_DECL_NODE(mul1); + PATTERN_DECL_NODE(mul2); + PATTERN_DECL_NODE(mul0_w); + PATTERN_DECL_NODE(mul1_w); + PATTERN_DECL_NODE(mul2_w); + PATTERN_DECL_NODE(mul0_out); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(mul2_out); + + PATTERN_DECL_NODE(elementwise0); + PATTERN_DECL_NODE(elementwise1); + PATTERN_DECL_NODE(elementwise2); + + PATTERN_DECL_NODE(elementwise0_w); + PATTERN_DECL_NODE(elementwise1_w); + PATTERN_DECL_NODE(elementwise2_w); + + PATTERN_DECL_NODE(elementwise0_out); + PATTERN_DECL_NODE(elementwise1_out); + PATTERN_DECL_NODE(elementwise2_out); + + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(transpose2_qkv_out); + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); +}; + +} // namespace patterns + +class TrtQkMultiHeadMatmulFusePass : public FusePassBase { + public: + virtual ~TrtQkMultiHeadMatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"trt_qk_multihead_matmul_fuse"}; + + private: + int BuildQkFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index b4c35e82c6e..b07c47b81ef 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2569,6 +2569,7 @@ USE_TRT_CONVERTER(preln_groupnorm_act) #if IS_TRT_VERSION_GE(8522) USE_TRT_CONVERTER(flash_multihead_matmul) USE_TRT_CONVERTER(cross_multihead_matmul) +USE_TRT_CONVERTER(qk_multihead_matmul) #endif #if IS_TRT_VERSION_GE(8510) USE_TRT_CONVERTER(grid_sampler) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 8b1399515ed..3cc8b077ad7 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -108,6 +108,7 @@ const std::vector kTRTSubgraphPasses({ "trt_flash_multihead_matmul_fuse_pass", // "trt_cross_multihead_matmul_fuse_pass", // "vit_attention_fuse_pass", // + "trt_qk_multihead_matmul_fuse_pass", // "layernorm_shift_partition_fuse_pass", // "merge_layernorm_fuse_pass", // #if !defined _WIN32 diff --git 
a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index a47267ac3a5..cbe26a3d31e 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -28,6 +28,7 @@ list( multihead_matmul_roformer_op.cc flash_multihead_matmul_op.cc cross_multihead_matmul_op.cc + qk_multihead_matmul_op.cc grid_sampler_op.cc shuffle_channel_op.cc fill_any_like_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/qk_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/qk_multihead_matmul_op.cc new file mode 100644 index 00000000000..89b65e95bd8 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/qk_multihead_matmul_op.cc @@ -0,0 +1,301 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class QkMultiheadMatMulOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(3) << "convert a qk_multihead_mamul op to a corresponding tensorrt " + "network structure"; + + framework::OpDesc op_desc(op, nullptr); + auto* input_qk = engine_->GetITensor(op_desc.Input("Input_qk").front()); + auto* input_v = engine_->GetITensor(op_desc.Input("Input_v").front()); + + auto output_name = op_desc.Output("Out")[0]; + + /* ------------------ weight_qk -------------------------*/ + auto weight_qk_name = op_desc.Input("W_qk").front(); + auto* weight_qk_v = scope.FindVar(weight_qk_name); + auto* weight_qk_t = weight_qk_v->GetMutable(); + float* weight_qk_data = nullptr; + weight_qk_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(weight_qk_name, *weight_qk_t).get().values)); + + const auto& weight_qk_dims = + weight_qk_t->dims(); // hidden_in_qk 2 hidden_out_qk + int hidden_in_qk = weight_qk_dims[0]; + int num_qk = weight_qk_dims[1]; + int hidden_out_qk = weight_qk_dims[2]; + int head_number_qk = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number")); + int head_size_qk = hidden_out_qk / head_number_qk; + int n_qk = num_qk * hidden_out_qk; + + // [hidden_in, 2, head_number, head_size] + // -> [head_number, 2, head_size, hidden_in] + auto transpose_weight_qk = [](const float* src, + float* dst, + int two, + int head_number, + int head_size, + int hidden_in) { + for (int hn = 0; hn < head_number; hn++) { + for (int t = 0; t < two; t++) { + for (int hs = 0; hs < head_size; hs++) { + for (int hi = 0; hi < hidden_in; hi++) { + int out_index = hn * two * head_size * hidden_in + + t * head_size * hidden_in + hs * hidden_in + hi; + int in_index = hi * two * head_number * head_size + + t * head_number * head_size + hn * head_size + hs; + dst[out_index] = src[in_index]; + } + } + } + } + }; + + std::vector weight_qk_data_tmp; + weight_qk_data_tmp.reserve(weight_qk_t->numel()); + 
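    // Stage the fused QK weight in a temporary buffer, then rewrite it in place
    // from [hidden_in, 2, head_number, head_size] to
    // [head_number, 2, head_size, hidden_in], so the FC output channels come out
    // in (head, q/k, head_size) order, matching the reshape fed to the fMHA_V2 plugin.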
memcpy(weight_qk_data_tmp.data(), + weight_qk_data, + weight_qk_t->numel() * sizeof(float)); + transpose_weight_qk(weight_qk_data_tmp.data(), + weight_qk_data, + num_qk, + head_number_qk, + head_size_qk, + hidden_in_qk); + + /* ------------------ bias_qk -------------------------*/ + auto bias_qk_name = op_desc.Input("B_qk").front(); + auto* bias_qk_v = scope.FindVar(bias_qk_name); + auto* bias_qk_t = bias_qk_v->GetMutable(); + float* bias_qk_data = nullptr; + bias_qk_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(bias_qk_name, *bias_qk_t).get().values)); + + // [2, head_number, head_size] -> [head_number, 2, head_size] + auto transpose_bias_qk = [](const float* src, float* dst, int N, int H) { + for (int i = 0; i < 2; ++i) { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 2 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } + } + }; + + std::vector bias_qk_data_tmp; + bias_qk_data_tmp.reserve(bias_qk_t->numel()); + memcpy(bias_qk_data_tmp.data(), + bias_qk_data, + bias_qk_t->numel() * sizeof(float)); + transpose_bias_qk( + bias_qk_data_tmp.data(), bias_qk_data, head_number_qk, head_size_qk); + + auto weight_qk_shape = nvinfer1::Dims3{1, n_qk, hidden_in_qk}; + auto* weight_qk_tensor = + AddConstantLayer(weight_qk_data, weight_qk_shape, " "); + auto bias_qk_shape = nvinfer1::Dims3{1, 1, n_qk}; + auto* bias_qk_tensor = AddConstantLayer(bias_qk_data, bias_qk_shape, " "); + nvinfer1::ITensor* input_qk_shape_tensor = Shape(input_qk); + + nvinfer1::ILayer* fc_qk_layer = nullptr; + nvinfer1::ILayer* merge_qk_element_layer = nullptr; + nvinfer1::MatrixOperation matrix_operation_X = + nvinfer1::MatrixOperation::kNONE; + nvinfer1::MatrixOperation matrix_operation_Y = + nvinfer1::MatrixOperation::kTRANSPOSE; + fc_qk_layer = TRT_ENGINE_ADD_LAYER(engine_, + MatrixMultiply, + *input_qk, + matrix_operation_X, + *weight_qk_tensor, + matrix_operation_Y); + fc_qk_layer->setName( + ("qk_attention_matrix_multiply(Output: " + output_name + ")").c_str()); + + // add qk ElementWiseLayer layer + nvinfer1::ElementWiseOperation elementwise_operation = + nvinfer1::ElementWiseOperation::kSUM; + merge_qk_element_layer = TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *fc_qk_layer->getOutput(0), + *bias_qk_tensor, + elementwise_operation); + merge_qk_element_layer->setName( + ("multihead_mamul_fc_qk(Output: " + output_name + ")").c_str()); + + auto* reshape_after_fc_qk_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, *merge_qk_element_layer->getOutput(0)); + std::vector mha_input_qk_tensor_shape; + for (int i = 0; i < 5; i++) { + mha_input_qk_tensor_shape.push_back(Add1DConstantLayer(1)); + } + mha_input_qk_tensor_shape[0] = + GetEleTensorOfShape(input_qk_shape_tensor, 0); + mha_input_qk_tensor_shape[1] = + GetEleTensorOfShape(input_qk_shape_tensor, 1); + mha_input_qk_tensor_shape[2] = Add1DConstantLayer(head_number_qk); + mha_input_qk_tensor_shape[3] = Add1DConstantLayer(2); + mha_input_qk_tensor_shape[4] = Add1DConstantLayer(head_size_qk); + reshape_after_fc_qk_layer->setInput(1, *Concat(mha_input_qk_tensor_shape)); + reshape_after_fc_qk_layer->setName( + ("shuffle_after_fc_qk_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + /* ------------------ weight_v -------------------------*/ + auto weight_v_name = op_desc.Input("W_v").front(); + auto* weight_v_v = scope.FindVar(weight_v_name); + auto* weight_v_t = weight_v_v->GetMutable(); + float* weight_v_data = nullptr; + weight_v_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(weight_v_name, 
*weight_v_t).get().values)); + int n_v = hidden_out_qk; + + // [hidden_in, head_number, head_size] + // -> [head_number, head_size, hidden_in] + auto transpose_weight_v = [](const float* src, + float* dst, + int head_number, + int head_size, + int hidden_in) { + for (int hn = 0; hn < head_number; hn++) { + for (int hs = 0; hs < head_size; hs++) { + for (int hi = 0; hi < hidden_in; hi++) { + int out_index = hn * head_size * hidden_in + hs * hidden_in + hi; + int in_index = hi * head_number * head_size + hn * head_size + hs; + dst[out_index] = src[in_index]; + } + } + } + }; + std::vector weight_v_data_tmp; + weight_v_data_tmp.reserve(weight_v_t->numel()); + memcpy(weight_v_data_tmp.data(), + weight_v_data, + weight_v_t->numel() * sizeof(float)); + transpose_weight_v(weight_v_data_tmp.data(), + weight_v_data, + head_number_qk, + head_size_qk, + hidden_in_qk); + + /* ------------------ bias_v -------------------------*/ + auto bias_v_name = op_desc.Input("B_v").front(); + auto* bias_v_v = scope.FindVar(bias_v_name); + auto* bias_v_t = bias_v_v->GetMutable(); + float* bias_v_data = nullptr; + bias_v_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(bias_v_name, *bias_v_t).get().values)); + + auto weight_v_shape = nvinfer1::Dims3{1, n_v, hidden_in_qk}; + auto* weight_v_tensor = + AddConstantLayer(weight_v_data, weight_v_shape, " "); + auto bias_v_shape = nvinfer1::Dims3{1, 1, n_v}; + auto* bias_v_tensor = AddConstantLayer(bias_v_data, bias_v_shape, " "); + nvinfer1::ITensor* input_v_shape_tensor = Shape(input_v); + + nvinfer1::ILayer* fc_v_layer = nullptr; + nvinfer1::ILayer* merge_v_element_layer = nullptr; + fc_v_layer = TRT_ENGINE_ADD_LAYER(engine_, + MatrixMultiply, + *input_v, + matrix_operation_X, + *weight_v_tensor, + matrix_operation_Y); + fc_v_layer->setName( + ("v_attention_matrix_multiply(Output: " + output_name + ")").c_str()); + + // add v ElementWiseLayer layer + merge_v_element_layer = TRT_ENGINE_ADD_LAYER(engine_, + ElementWise, + *fc_v_layer->getOutput(0), + *bias_v_tensor, + elementwise_operation); + merge_v_element_layer->setName( + ("multihead_mamul_fc_v(Output: " + output_name + ")").c_str()); + + // add shuffle for fc layer + auto* reshape_after_fc_v_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, *merge_v_element_layer->getOutput(0)); + std::vector mha_input_v_tensor_shape; + for (int i = 0; i < 5; i++) { + mha_input_v_tensor_shape.push_back(Add1DConstantLayer(1)); + } + mha_input_v_tensor_shape[0] = GetEleTensorOfShape(input_v_shape_tensor, 0); + mha_input_v_tensor_shape[1] = GetEleTensorOfShape(input_v_shape_tensor, 1); + mha_input_v_tensor_shape[2] = Add1DConstantLayer(head_number_qk); + mha_input_v_tensor_shape[3] = Add1DConstantLayer(1); + mha_input_v_tensor_shape[4] = Add1DConstantLayer(head_size_qk); + reshape_after_fc_v_layer->setInput(1, *Concat(mha_input_v_tensor_shape)); + reshape_after_fc_v_layer->setName( + ("shuffle_after_fc_v_multihead_matmul(Output: " + output_name + ")") + .c_str()); + + std::vector mha_input_tensor_vector{ + reshape_after_fc_qk_layer->getOutput(0), + reshape_after_fc_v_layer->getOutput(0)}; + nvinfer1::ITensor* mha_input_tensor = Concat(mha_input_tensor_vector, 3); + auto creator = GetPluginRegistry()->getPluginCreator("fMHA_V2", "1"); + assert(creator != nullptr); + std::vector fields{}; + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast( + malloc(sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + + plugin_collection->nbFields = static_cast(fields.size()); 
+ plugin_collection->fields = fields.data(); + auto plugin = creator->createPlugin("fMHA_V2", plugin_collection); + free(plugin_collection); + std::vector plugin_inputs; + plugin_inputs.emplace_back(mha_input_tensor); + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + + // add shuffle + nvinfer1::ITensor* batch_tensor = + GetEleTensorOfShape(input_qk_shape_tensor, 0); + nvinfer1::ITensor* length_tensor = + GetEleTensorOfShape(input_qk_shape_tensor, 1); + auto* reshape_after_mha_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0)); + std::vector reshape_tensor; + reshape_tensor.push_back(batch_tensor); + reshape_tensor.push_back(length_tensor); + reshape_tensor.push_back(Add1DConstantLayer(-1)); + reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor)); + reshape_after_mha_layer->setName( + ("shuffle_last_multihead_matmul(Output: " + output_name + ")").c_str()); + nvinfer1::ILayer* layer = nullptr; + layer = reshape_after_mha_layer; + RreplenishLayerAndOutput( + layer, "qk_multihead_matmul", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(qk_multihead_matmul, QkMultiheadMatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 685fc44d7b3..24dca82d3fb 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -70,6 +70,8 @@ struct SimpleOpTypeSetTeller : public Teller { int8_teller_set.insert("flash_multihead_matmul"); teller_set.insert("cross_multihead_matmul"); int8_teller_set.insert("cross_multihead_matmul"); + teller_set.insert("qk_multihead_matmul"); + int8_teller_set.insert("qk_multihead_matmul"); #endif #if IS_TRT_VERSION_GE(8200) teller_set.insert("round"); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_qk_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_qk_multihead_matmul.py new file mode 100644 index 00000000000..548f0486e12 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_qk_multihead_matmul.py @@ -0,0 +1,385 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
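# This test builds the unfused graph that trt_qk_multihead_matmul_fuse_pass matches:
# Q and K are projected from input_data1 and V from input_data2, with hidden size 256
# split into 8 heads of 32, and scale 0.17677 ~= 1 / sqrt(32).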
+ +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest + +import paddle.inference as paddle_infer + + +class TrtConvertQkAttentionTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + ver = paddle_infer.get_trt_compile_version() + if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8520: + return False + return True + + def sample_program_configs(self): + def generate_input1(batch, length): + return np.random.rand(batch, length, 256).astype(np.float32) / 10 + + def generate_input2(batch, length): + return np.random.rand(batch, length, 256).astype(np.float32) / 10 + + def generate_weight_q(): + return np.random.rand(256, 256).astype(np.float32) / 10 + + def generate_weight_k(): + return np.random.rand(256, 256).astype(np.float32) / 10 + + def generate_weight_v(): + return np.random.rand(256, 256).astype(np.float32) / 10 + + def generate_bias_q(): + return np.random.rand(256).astype(np.float32) / 10 + + def generate_bias_k(): + return np.random.rand(256).astype(np.float32) / 10 + + def generate_bias_v(): + return np.random.rand(256).astype(np.float32) / 10 + + for batch in [1, 2]: + self.batch = batch + for length in [300, 400]: + ops_config = [ + # q + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data1"], + "Y": ["matmul_q_weight"], + }, + "op_outputs": {"Out": ["matmul_q_output"]}, + "op_attrs": {"trans_x": False, "trans_y": False}, + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["matmul_q_output"], + "Y": ["bias_q"], + }, + "op_outputs": {"Out": ["elementwise_q_output"]}, + "op_attrs": { + "Scale_out": 1.0, + "Scale_x": 1.0, + "Scale_y": 1.0, + "axis": 2, + }, + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["elementwise_q_output"], + }, + "op_outputs": { + "Out": ["reshape_q_output"], + "XShape": ["reshape_q_output_xshape"], + }, + "op_attrs": {"shape": [0, 0, 8, 32]}, + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape_q_output"]}, + "op_outputs": { + "Out": ["transpose_q_output"], + "XShape": ["transpose_q_output_xshape"], + }, + "op_attrs": { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout", + }, + }, + # k + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data1"], + "Y": ["matmul_k_weight"], + }, + "op_outputs": {"Out": ["matmul_k_output"]}, + "op_attrs": {"trans_x": False, "trans_y": False}, + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["matmul_k_output"], + "Y": ["bias_k"], + }, + "op_outputs": {"Out": ["elementwise_k_output"]}, + "op_attrs": { + "Scale_out": 1.0, + "Scale_x": 1.0, + "Scale_y": 1.0, + "axis": 2, + }, + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["elementwise_k_output"], + }, + "op_outputs": { + "Out": ["reshape_k_output"], + "XShape": ["reshape_k_output_xshape"], + }, + "op_attrs": {"shape": [0, 0, 8, 32]}, + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape_k_output"]}, + "op_outputs": { + "Out": ["transpose_k_output"], + "XShape": ["transpose_k_output_xshape"], + }, + "op_attrs": { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout", + }, + }, + # V + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["input_data2"], + "Y": ["matmul_v_weight"], + }, + "op_outputs": {"Out": ["matmul_v_output"]}, + "op_attrs": {"trans_x": False, "trans_y": False}, + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["matmul_v_output"], + 
"Y": ["bias_v"], + }, + "op_outputs": {"Out": ["elementwise_v_output"]}, + "op_attrs": { + "Scale_out": 1.0, + "Scale_x": 1.0, + "Scale_y": 1.0, + "axis": 2, + }, + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["elementwise_v_output"], + }, + "op_outputs": { + "Out": ["reshape_v_output"], + "XShape": ["reshape_v_output_xshape"], + }, + "op_attrs": {"shape": [0, 0, 8, 32]}, + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["reshape_v_output"]}, + "op_outputs": { + "Out": ["transpose_v_output"], + "XShape": ["transpose_v_output_xshape"], + }, + "op_attrs": { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout", + }, + }, + # matmul1+matmul2 + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["transpose_q_output"], + "Y": ["transpose_k_output"], + }, + "op_outputs": {"Out": ["matmul1_output"]}, + "op_attrs": {"trans_x": False, "trans_y": True}, + }, + { + "op_type": "scale", + "op_inputs": { + "X": ["matmul1_output"], + }, + "op_outputs": {"Out": ["scale_output"]}, + "op_attrs": { + "scale": 0.17677, + "bias": 0.0, + "bias_after_scale": True, + }, + }, + { + "op_type": "softmax", + "op_inputs": {"X": ["scale_output"]}, + "op_outputs": {"Out": ["softmax_output"]}, + "op_attrs": { + "axis": -1, + "data_format": "AnyLayout", + }, + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["softmax_output"], + "Y": ["transpose_v_output"], + }, + "op_outputs": {"Out": ["matmul2_output"]}, + "op_attrs": {"trans_x": False, "trans_y": False}, + }, + { + "op_type": "transpose2", + "op_inputs": {"X": ["matmul2_output"]}, + "op_outputs": { + "Out": ["transpose_output"], + "XShape": ["transpose_output_xshape"], + }, + "op_attrs": { + "axis": [0, 2, 1, 3], + "data_format": "AnyLayout", + }, + }, + { + "op_type": "reshape2", + "op_inputs": {"X": ["transpose_output"]}, + "op_outputs": { + "Out": ["reshape_output"], + "XShape": ["reshape_output_xshape"], + }, + "op_attrs": {"shape": [0, 0, 256]}, + }, + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "matmul_q_weight": TensorConfig( + data_gen=partial(generate_weight_q) + ), + "matmul_k_weight": TensorConfig( + data_gen=partial(generate_weight_k) + ), + "matmul_v_weight": TensorConfig( + data_gen=partial(generate_weight_v) + ), + "bias_q": TensorConfig( + data_gen=partial(generate_bias_q) + ), + "bias_k": TensorConfig( + data_gen=partial(generate_bias_k) + ), + "bias_v": TensorConfig( + data_gen=partial(generate_bias_v) + ), + }, + inputs={ + "input_data1": TensorConfig( + data_gen=partial(generate_input1, batch, length) + ), + "input_data2": TensorConfig( + data_gen=partial(generate_input2, batch, length) + ), + }, + outputs=["reshape_output"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + # The last dim of input1 and input2 should be static. 
+ self.dynamic_shape.min_input_shape = { + "input_data1": [1, 300, 256], + "input_data2": [1, 300, 256], + } + self.dynamic_shape.max_input_shape = { + "input_data1": [4, 1200, 256], + "input_data2": [4, 1200, 256], + } + self.dynamic_shape.opt_input_shape = { + "input_data1": [1, 300, 256], + "input_data2": [1, 300, 256], + } + + def clear_dynamic_shape(): + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + self.trt_param.workspace_size = 2013265920 + yield self.create_inference_config(), (1, 3), (1e-5, 1e-5) + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), (1, 3), (1e-3, 1e-3) + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + self.trt_param.workspace_size = 2013265920 + yield self.create_inference_config(), (1, 3), (1e-5, 1e-4) + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), (1, 3), (1e-2, 1e-3) + + def add_skip_trt_case(self): + def teller1(program_config, predictor_config): + if self.dynamic_shape.min_input_shape == {}: + return True + return False + + self.add_skip_case( + teller1, + SkipReasons.TRT_NOT_IMPLEMENTED, + "The qk attention trt oss plugin do not support static shape yet", + ) + + def teller2(program_config, predictor_config): + if self.trt_param.precision == paddle_infer.PrecisionType.Float32: + return True + return False + + self.add_skip_case( + teller2, + SkipReasons.TRT_NOT_IMPLEMENTED, + "The qk attention trt oss plugin do not support fp32 yet", + ) + + def teller3(program_config, predictor_config): + if self.trt_param.precision == paddle_infer.PrecisionType.Int8: + return True + return False + + self.add_skip_case( + teller3, + SkipReasons.TRT_NOT_IMPLEMENTED, + "The qk attention trt oss plugin do not support int8 yet.", + ) + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab
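One detail of the converter worth spelling out: the fMHA_V2 plugin receives a single packed tensor. The QK projection is reshaped to [batch, seq_len, head_number, 2, head_size], the V projection to [batch, seq_len, head_number, 1, head_size], and the two are concatenated along axis 3; the final Shuffle then restores [batch, seq_len, head_number * head_size]. A numpy sketch of that packing (illustrative only; the plugin itself runs inside TensorRT):

import numpy as np

def pack_fmha_inputs(qk_fc_out, v_fc_out, head_number, head_size):
    # qk_fc_out: [batch, seq_len, 2 * head_number * head_size], channels in (head, q/k, head_size) order
    # v_fc_out:  [batch, seq_len, head_number * head_size], channels in (head, head_size) order
    batch, seq_len, _ = qk_fc_out.shape
    qk = qk_fc_out.reshape(batch, seq_len, head_number, 2, head_size)
    v = v_fc_out.reshape(batch, seq_len, head_number, 1, head_size)
    # Concatenate along axis 3: [batch, seq_len, head_number, 3, head_size]
    return np.concatenate([qk, v], axis=3)

# After the plugin runs, the converter's last Shuffle reshapes the output back to
# (batch, seq_len, -1), i.e. [batch, seq_len, head_number * head_size].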