Unverified commit 6934ac79, authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Support two inputs of multihead attention named qk_multihead. (#52455)

* Support two inputs of multihead attention named qk_multihead
Parent 01247e33
@@ -134,6 +134,7 @@ if(WITH_TENSORRT)
pass_library(trt_multihead_matmul_fuse_pass inference)
pass_library(trt_flash_multihead_matmul_fuse_pass inference)
pass_library(trt_cross_multihead_matmul_fuse_pass inference)
pass_library(trt_qk_multihead_matmul_fuse_pass inference)
pass_library(trt_skip_layernorm_fuse_pass inference)
pass_library(merge_layernorm_fuse_pass inference)
pass_library(preln_skip_layernorm_fuse_pass inference)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/trt_qk_multihead_matmul_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h"
#endif
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
//   input_qk                   input_v
//     |q          |k              |v
//     |-----------|               |
//   matmul      matmul          matmul
//     |           |               |
//  reshape     reshape         reshape
//     |           |               |
//   trans       trans           trans
//     |(x)        |(y)            |
//     |--matmul---|               |
//          |                      |
//        scale                    |
//          |                      |
//       softmax                   |(y)
//          |--------matmul--------|
//                     |
//                   trans
//                     |
//                  reshape
//                     |
//                   output
//
// -> fused to
//
//   input_qk    input_v
//       |          |
//     qk_multihead_matmul
//              |
//           output
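// Per head, the fused op computes
//   Q = input_qk * Wq + bq,  K = input_qk * Wk + bk,  V = input_v * Wv + bv
//   Out = softmax(scale * Q * K^T) * V
// Q and K are projected from the same tensor (input_qk), while V comes from a
// second input (input_v); that shared QK input is what this pattern matches.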
PDNode* TrtQKMultiHeadMatmulPattern::operator()() {
std::unordered_set<std::string> mul_ops{"mul", "matmul_v2"};
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
auto* input1 = pattern->NewNode(input1_repr());
input0->assert_is_ops_input(mul_ops);
input1->assert_is_ops_input(mul_ops);
VLOG(5) << "Start match TrtQKMultiHeadMatmulPattern";
// First path
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_ops(mul_ops);
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_ops_input(mul_ops, "Y");
auto* mul0_out_var = pattern->NewNode(mul0_out_repr())
->assert_is_ops_output(mul_ops)
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* elementwise0 =
pattern->NewNode(elementwise0_repr())->assert_is_op("elementwise_add");
auto* elementwise0_w = pattern->NewNode(elementwise0_w_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* elementwise0_out = pattern->NewNode(elementwise0_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("reshape2", "X")
->AsIntermediate();
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())
->assert_is_op_output("reshape2")
->assert_is_op_input("transpose2")
->AsIntermediate();
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2")
->assert_is_ops_input(matmul_ops, "X")
->AsIntermediate();
auto* matmul_qk =
pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
auto* matmul_qk_out_var = pattern->NewNode(matmul_qk_out_repr())
->assert_is_ops_output(matmul_ops)
->assert_is_op_input("scale")
->AsIntermediate();
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out_var = pattern->NewNode(scale_out_repr())
->assert_is_op_output("scale")
->assert_is_op_input("softmax")
->AsIntermediate();
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax")
->assert_is_ops_input(matmul_ops)
->AsIntermediate();
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var = pattern->NewNode(matmul_qkv_out_repr())
->assert_is_ops_output(matmul_ops)
->assert_is_op_input("transpose2")
->AsIntermediate();
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2")
->assert_is_op_input("reshape2")
->AsIntermediate();
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2")
->AsOutput();
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(mul_ops);
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_ops_input(mul_ops, "Y");
auto* mul1_out_var = pattern->NewNode(mul1_out_repr())
->assert_is_ops_output(mul_ops)
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* elementwise1 =
pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
auto* elementwise1_w = pattern->NewNode(elementwise1_w_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* elementwise1_out = pattern->NewNode(elementwise1_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("reshape2", "X")
->AsIntermediate();
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr())
->assert_is_op_output("reshape2")
->assert_is_op_input("transpose2")
->AsIntermediate();
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2")
->assert_is_ops_input(matmul_ops, "Y")
->AsIntermediate(); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_ops(mul_ops);
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_ops_input(mul_ops, "Y");
auto* mul2_out_var = pattern->NewNode(mul2_out_repr())
->assert_is_ops_output(mul_ops)
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* elementwise2 =
pattern->NewNode(elementwise2_repr())->assert_is_op("elementwise_add");
auto* elementwise2_w = pattern->NewNode(elementwise2_w_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* elementwise2_out = pattern->NewNode(elementwise2_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("reshape2", "X")
->AsIntermediate();
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr())
->assert_is_op_output("reshape2")
->assert_is_op_input("transpose2")
->AsIntermediate();
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2")
->assert_is_ops_input(matmul_ops)
->AsIntermediate(); // link to matmul qkv
// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
elementwise0->LinksFrom({mul0_out_var, elementwise0_w})
.LinksTo({elementwise0_out});
reshape2_0->LinksFrom({elementwise0_out}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
elementwise1->LinksFrom({mul1_out_var, elementwise1_w})
.LinksTo({elementwise1_out});
reshape2_1->LinksFrom({elementwise1_out}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// compute q*k
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
scale->LinksFrom({matmul_qk_out_var}).LinksTo({scale_out_var});
softmax_qk->LinksFrom({scale_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({input1, mul2_w_var}).LinksTo({mul2_out_var});
elementwise2->LinksFrom({mul2_out_var, elementwise2_w})
.LinksTo({elementwise2_out});
reshape2_2->LinksFrom({elementwise2_out}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// compute q*k*v
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
return reshape2_qkv_out_var;
}
} // namespace patterns
int TrtQkMultiHeadMatmulFusePass::BuildQkFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::TrtQKMultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
multihead_pattern();
auto fuse_creater = [&](Node* input0,
Node* input1,
Node* mul0,
Node* mul1,
Node* mul2,
Node* mul0_out,
Node* mul1_out,
Node* mul2_out,
Node* mul0_w,
Node* mul1_w,
Node* mul2_w,
Node* elementwise0,
Node* elementwise0_w,
Node* elementwise1,
Node* elementwise1_w,
Node* elementwise2,
Node* elementwise2_w,
Node* reshape2,
Node* reshape2_qkv_out,
Node* scale,
Node* scale_out) {
// get Device context
auto* dev_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
auto scale_attr = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale"));
// create multihead
OpDesc multihead_op_desc(mul0->Op()->Block());
auto reshape_desc = reshape2->Op();
int head_number =
PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
.at(2);
multihead_op_desc.SetType("qk_multihead_matmul");
multihead_op_desc.SetInput("Input_qk", {input0->Name()});
multihead_op_desc.SetInput("Input_v", {input1->Name()});
auto* wq_tensor =
scope->FindVar(mul0_w->Name())->GetMutable<phi::DenseTensor>();
auto* wk_tensor =
scope->FindVar(mul1_w->Name())->GetMutable<phi::DenseTensor>();
auto* bq_tensor =
scope->FindVar(elementwise0_w->Name())->GetMutable<phi::DenseTensor>();
auto* bk_tensor =
scope->FindVar(elementwise1_w->Name())->GetMutable<phi::DenseTensor>();
int hidden_out = wq_tensor->dims()[1];
int head_size = hidden_out / head_number;
if (std::abs(scale_attr - 1.0f / std::sqrt(static_cast<float>(head_size))) >
1e-5) {
VLOG(3) << "scale of muilthead matmul do not fit the requirement of "
"qk attention plugin, Stop fusing.";
return;
}
VLOG(3) << "trt qk attention get wq_tensor name = " << mul0_w->Name()
<< "trt qk attention get wk_tensor name = " << mul1_w->Name();
auto* wq_data = wq_tensor->data<float>();
auto* wk_data = wk_tensor->data<float>();
auto* bq_data = bq_tensor->data<float>();
auto* bk_data = bk_tensor->data<float>();
// combined_w_dims = [in,2,out]
auto combined_w_qk_dims =
phi::make_ddim({wq_tensor->dims()[0], 2, wq_tensor->dims()[1]});
auto combined_bias_dims = phi::make_ddim({2, bq_tensor->dims()[0]});
VLOG(3) << "trt qk attention trt wq_dim in:" << wq_tensor->dims()[0]
<< "trt qk attention trt wk_dim out:" << wq_tensor->dims()[1];
auto* combined_w_qk_desc = mul0_w->Var();
combined_w_qk_desc->SetShape(
{wq_tensor->dims()[0], 2, wq_tensor->dims()[1]});
combined_w_qk_desc->SetPersistable(true);
phi::DenseTensor tmp_combined_w_qk_tensor;
tmp_combined_w_qk_tensor.Resize(combined_w_qk_dims);
float* tmp_combined_w_qk_data =
dev_ctx->template HostAlloc<float>(&tmp_combined_w_qk_tensor);
std::vector<float*> w_vec = {wq_data, wk_data};
int dims_h = combined_w_qk_dims[0], dims_w = combined_w_qk_dims[2];
// dims_h=in_feature, dims_w=out_feature
// Combine the two fc weights together.
// weight [Hidden_in * 2 * N * H]
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 2; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (2 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
tmp_combined_w_qk_data[out_index] = w_vec[j][in_index];
}
}
}
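// For example, with dims_h = 2 and dims_w = 3 the combined buffer is laid out
// row by row as
//   [q00 q01 q02 | k00 k01 k02 | q10 q11 q12 | k10 k11 k12]
// i.e. shape [in, 2, out], matching combined_w_qk_dims above.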
wq_tensor->clear();
wq_tensor->Resize(combined_w_qk_dims);
auto* new_combined_w_qk_data = dev_ctx->template HostAlloc<float>(
wq_tensor, sizeof(float) * wq_tensor->numel());
memcpy(new_combined_w_qk_data,
tmp_combined_w_qk_data,
sizeof(float) * wq_tensor->numel());
scope->EraseVars({mul1_w->Name()});
auto* combined_bias_desc = elementwise0_w->Var();
combined_bias_desc->SetShape({2, bq_tensor->dims()[0]});
combined_bias_desc->SetPersistable(true);
phi::DenseTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
float* tmp_combined_bias_data =
dev_ctx->template HostAlloc<float>(&tmp_combined_bias_tensor);
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(
tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
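// The combined bias is simply [bq | bk], i.e. shape [2, N * H], mirroring the
// [in, 2, out] packing of the combined weight.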
bq_tensor->clear();
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data = dev_ctx->template HostAlloc<float>(
bq_tensor, sizeof(float) * bq_tensor->numel());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({elementwise1_w->Name()});
multihead_op_desc.SetInput("W_qk", {mul0_w->Name()});
multihead_op_desc.SetInput("W_v", {mul2_w->Name()});
multihead_op_desc.SetInput("B_qk", {elementwise0_w->Name()});
multihead_op_desc.SetInput("B_v", {elementwise2_w->Name()});
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
multihead_op_desc.SetAttr("alpha", scale_attr);
multihead_op_desc.SetAttr("head_number", head_number);
auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(input0, multihead);
IR_NODE_LINK_TO(input1, multihead);
IR_NODE_LINK_TO(mul0_w, multihead);
IR_NODE_LINK_TO(mul2_w, multihead);
IR_NODE_LINK_TO(elementwise0_w, multihead);
IR_NODE_LINK_TO(elementwise2_w, multihead);
IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input1, input1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise0, elementwise0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise0_w, elementwise0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise0_out, elementwise0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise1_w, elementwise1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise1_out, elementwise1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise2, elementwise2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise2_w, elementwise2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise2_out, elementwise2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);
fuse_creater(input0,
input1,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul0_w,
mul1_w,
mul2_w,
elementwise0,
elementwise0_w,
elementwise1,
elementwise1_w,
elementwise2,
elementwise2_w,
reshape2_0,
reshape2_qkv_out,
scale,
scale_out);
std::unordered_set<const Node*> marked_nodes({reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
softmax_qk,
softmax_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
elementwise0,
elementwise0_out,
elementwise1,
elementwise1_w,
elementwise1_out,
elementwise2,
elementwise2_out,
reshape2_qkv,
scale});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void TrtQkMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
#ifdef PADDLE_WITH_TENSORRT
auto trt_version = paddle::inference::tensorrt::GetTrtRuntimeVersion();
if (std::get<0>(trt_version) * 1000 + std::get<1>(trt_version) * 100 +
std::get<2>(trt_version) * 10 <
8520) {
VLOG(3) << "Qk attention oss plugin only available for trt version >= "
"8.5.2.2. Stop this pass";
return;
}
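// e.g. TensorRT 8.5.1 gives 8 * 1000 + 5 * 100 + 1 * 10 = 8510 < 8520, so the
// pass is skipped, while 8.5.2 gives exactly 8520 and the fusion proceeds.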
#else
// if no tensorrt, early stop
return;
#endif
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!with_dynamic_shape) {
VLOG(3) << "Qk attention oss plugin need trt "
"with_dynamic_shape. Stop this pass";
return;
}
auto* scope = param_scope();
int fusion_count = BuildQkFusion(graph, name_scope_, scope);
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(trt_qk_multihead_matmul_fuse_pass,
paddle::framework::ir::TrtQkMultiHeadMatmulFusePass);
REGISTER_PASS_CAPABILITY(trt_qk_multihead_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.EQ("softmax", 0)
.EQ("matmul_v2", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct TrtQKMultiHeadMatmulPattern : public PatternBase {
TrtQKMultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "qk_multihead_matmul") {}
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(input1);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(elementwise0);
PATTERN_DECL_NODE(elementwise1);
PATTERN_DECL_NODE(elementwise2);
PATTERN_DECL_NODE(elementwise0_w);
PATTERN_DECL_NODE(elementwise1_w);
PATTERN_DECL_NODE(elementwise2_w);
PATTERN_DECL_NODE(elementwise0_out);
PATTERN_DECL_NODE(elementwise1_out);
PATTERN_DECL_NODE(elementwise2_out);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
class TrtQkMultiHeadMatmulFusePass : public FusePassBase {
public:
virtual ~TrtQkMultiHeadMatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"trt_qk_multihead_matmul_fuse"};
private:
int BuildQkFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -2569,6 +2569,7 @@ USE_TRT_CONVERTER(preln_groupnorm_act)
#if IS_TRT_VERSION_GE(8522)
USE_TRT_CONVERTER(flash_multihead_matmul)
USE_TRT_CONVERTER(cross_multihead_matmul)
USE_TRT_CONVERTER(qk_multihead_matmul)
#endif
#if IS_TRT_VERSION_GE(8510)
USE_TRT_CONVERTER(grid_sampler)
@@ -108,6 +108,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_flash_multihead_matmul_fuse_pass", //
"trt_cross_multihead_matmul_fuse_pass", //
"vit_attention_fuse_pass", //
"trt_qk_multihead_matmul_fuse_pass", //
"layernorm_shift_partition_fuse_pass", //
"merge_layernorm_fuse_pass", //
#if !defined _WIN32
@@ -28,6 +28,7 @@ list(
multihead_matmul_roformer_op.cc
flash_multihead_matmul_op.cc
cross_multihead_matmul_op.cc
qk_multihead_matmul_op.cc
grid_sampler_op.cc
shuffle_channel_op.cc
fill_any_like_op.cc
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class QkMultiheadMatMulOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a qk_multihead_mamul op to a corresponding tensorrt "
"network structure";
framework::OpDesc op_desc(op, nullptr);
auto* input_qk = engine_->GetITensor(op_desc.Input("Input_qk").front());
auto* input_v = engine_->GetITensor(op_desc.Input("Input_v").front());
auto output_name = op_desc.Output("Out")[0];
/* ------------------ weight_qk -------------------------*/
auto weight_qk_name = op_desc.Input("W_qk").front();
auto* weight_qk_v = scope.FindVar(weight_qk_name);
auto* weight_qk_t = weight_qk_v->GetMutable<phi::DenseTensor>();
float* weight_qk_data = nullptr;
weight_qk_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(weight_qk_name, *weight_qk_t).get().values));
const auto& weight_qk_dims =
weight_qk_t->dims(); // [hidden_in_qk, 2, hidden_out_qk]
int hidden_in_qk = weight_qk_dims[0];
int num_qk = weight_qk_dims[1];
int hidden_out_qk = weight_qk_dims[2];
int head_number_qk = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number"));
int head_size_qk = hidden_out_qk / head_number_qk;
int n_qk = num_qk * hidden_out_qk;
// [hidden_in, 2, head_number, head_size]
// -> [head_number, 2, head_size, hidden_in]
auto transpose_weight_qk = [](const float* src,
float* dst,
int two,
int head_number,
int head_size,
int hidden_in) {
for (int hn = 0; hn < head_number; hn++) {
for (int t = 0; t < two; t++) {
for (int hs = 0; hs < head_size; hs++) {
for (int hi = 0; hi < hidden_in; hi++) {
int out_index = hn * two * head_size * hidden_in +
t * head_size * hidden_in + hs * hidden_in + hi;
int in_index = hi * two * head_number * head_size +
t * head_number * head_size + hn * head_size + hs;
dst[out_index] = src[in_index];
}
}
}
}
};
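// In index terms: src(hi, t, hn, hs) -> dst(hn, t, hs, hi), so for each head
// the Q weight block is immediately followed by that head's K weight block.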
std::vector<float> weight_qk_data_tmp;
weight_qk_data_tmp.resize(weight_qk_t->numel());
memcpy(weight_qk_data_tmp.data(),
weight_qk_data,
weight_qk_t->numel() * sizeof(float));
transpose_weight_qk(weight_qk_data_tmp.data(),
weight_qk_data,
num_qk,
head_number_qk,
head_size_qk,
hidden_in_qk);
/* ------------------ bias_qk -------------------------*/
auto bias_qk_name = op_desc.Input("B_qk").front();
auto* bias_qk_v = scope.FindVar(bias_qk_name);
auto* bias_qk_t = bias_qk_v->GetMutable<phi::DenseTensor>();
float* bias_qk_data = nullptr;
bias_qk_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(bias_qk_name, *bias_qk_t).get().values));
// [2, head_number, head_size] -> [head_number, 2, head_size]
auto transpose_bias_qk = [](const float* src, float* dst, int N, int H) {
for (int i = 0; i < 2; ++i) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
dst[n * 2 * H + i * H + h] = src[i * N * H + n * H + h];
}
}
}
};
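// In index terms: src(i, n, h) -> dst(n, i, h), so each head's Q bias is
// immediately followed by its K bias, matching the weight layout above.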
std::vector<float> bias_qk_data_tmp;
bias_qk_data_tmp.resize(bias_qk_t->numel());
memcpy(bias_qk_data_tmp.data(),
bias_qk_data,
bias_qk_t->numel() * sizeof(float));
transpose_bias_qk(
bias_qk_data_tmp.data(), bias_qk_data, head_number_qk, head_size_qk);
auto weight_qk_shape = nvinfer1::Dims3{1, n_qk, hidden_in_qk};
auto* weight_qk_tensor =
AddConstantLayer(weight_qk_data, weight_qk_shape, " ");
auto bias_qk_shape = nvinfer1::Dims3{1, 1, n_qk};
auto* bias_qk_tensor = AddConstantLayer(bias_qk_data, bias_qk_shape, " ");
nvinfer1::ITensor* input_qk_shape_tensor = Shape(input_qk);
nvinfer1::ILayer* fc_qk_layer = nullptr;
nvinfer1::ILayer* merge_qk_element_layer = nullptr;
nvinfer1::MatrixOperation matrix_operation_X =
nvinfer1::MatrixOperation::kNONE;
nvinfer1::MatrixOperation matrix_operation_Y =
nvinfer1::MatrixOperation::kTRANSPOSE;
fc_qk_layer = TRT_ENGINE_ADD_LAYER(engine_,
MatrixMultiply,
*input_qk,
matrix_operation_X,
*weight_qk_tensor,
matrix_operation_Y);
fc_qk_layer->setName(
("qk_attention_matrix_multiply(Output: " + output_name + ")").c_str());
// add qk ElementWiseLayer layer
nvinfer1::ElementWiseOperation elementwise_operation =
nvinfer1::ElementWiseOperation::kSUM;
merge_qk_element_layer = TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*fc_qk_layer->getOutput(0),
*bias_qk_tensor,
elementwise_operation);
merge_qk_element_layer->setName(
("multihead_mamul_fc_qk(Output: " + output_name + ")").c_str());
auto* reshape_after_fc_qk_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *merge_qk_element_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> mha_input_qk_tensor_shape;
for (int i = 0; i < 5; i++) {
mha_input_qk_tensor_shape.push_back(Add1DConstantLayer(1));
}
mha_input_qk_tensor_shape[0] =
GetEleTensorOfShape(input_qk_shape_tensor, 0);
mha_input_qk_tensor_shape[1] =
GetEleTensorOfShape(input_qk_shape_tensor, 1);
mha_input_qk_tensor_shape[2] = Add1DConstantLayer(head_number_qk);
mha_input_qk_tensor_shape[3] = Add1DConstantLayer(2);
mha_input_qk_tensor_shape[4] = Add1DConstantLayer(head_size_qk);
reshape_after_fc_qk_layer->setInput(1, *Concat(mha_input_qk_tensor_shape));
reshape_after_fc_qk_layer->setName(
("shuffle_after_fc_qk_multihead_matmul(Output: " + output_name + ")")
.c_str());
/* ------------------ weight_v -------------------------*/
auto weight_v_name = op_desc.Input("W_v").front();
auto* weight_v_v = scope.FindVar(weight_v_name);
auto* weight_v_t = weight_v_v->GetMutable<phi::DenseTensor>();
float* weight_v_data = nullptr;
weight_v_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(weight_v_name, *weight_v_t).get().values));
int n_v = hidden_out_qk;
// [hidden_in, head_number, head_size]
// -> [head_number, head_size, hidden_in]
auto transpose_weight_v = [](const float* src,
float* dst,
int head_number,
int head_size,
int hidden_in) {
for (int hn = 0; hn < head_number; hn++) {
for (int hs = 0; hs < head_size; hs++) {
for (int hi = 0; hi < hidden_in; hi++) {
int out_index = hn * head_size * hidden_in + hs * hidden_in + hi;
int in_index = hi * head_number * head_size + hn * head_size + hs;
dst[out_index] = src[in_index];
}
}
}
};
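// In index terms: src(hi, hn, hs) -> dst(hn, hs, hi).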
std::vector<float> weight_v_data_tmp;
weight_v_data_tmp.resize(weight_v_t->numel());
memcpy(weight_v_data_tmp.data(),
weight_v_data,
weight_v_t->numel() * sizeof(float));
transpose_weight_v(weight_v_data_tmp.data(),
weight_v_data,
head_number_qk,
head_size_qk,
hidden_in_qk);
/* ------------------ bias_v -------------------------*/
auto bias_v_name = op_desc.Input("B_v").front();
auto* bias_v_v = scope.FindVar(bias_v_name);
auto* bias_v_t = bias_v_v->GetMutable<phi::DenseTensor>();
float* bias_v_data = nullptr;
bias_v_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(bias_v_name, *bias_v_t).get().values));
auto weight_v_shape = nvinfer1::Dims3{1, n_v, hidden_in_qk};
auto* weight_v_tensor =
AddConstantLayer(weight_v_data, weight_v_shape, " ");
auto bias_v_shape = nvinfer1::Dims3{1, 1, n_v};
auto* bias_v_tensor = AddConstantLayer(bias_v_data, bias_v_shape, " ");
nvinfer1::ITensor* input_v_shape_tensor = Shape(input_v);
nvinfer1::ILayer* fc_v_layer = nullptr;
nvinfer1::ILayer* merge_v_element_layer = nullptr;
fc_v_layer = TRT_ENGINE_ADD_LAYER(engine_,
MatrixMultiply,
*input_v,
matrix_operation_X,
*weight_v_tensor,
matrix_operation_Y);
fc_v_layer->setName(
("v_attention_matrix_multiply(Output: " + output_name + ")").c_str());
// add v ElementWiseLayer layer
merge_v_element_layer = TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*fc_v_layer->getOutput(0),
*bias_v_tensor,
elementwise_operation);
merge_v_element_layer->setName(
("multihead_mamul_fc_v(Output: " + output_name + ")").c_str());
// add shuffle for fc layer
auto* reshape_after_fc_v_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *merge_v_element_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> mha_input_v_tensor_shape;
for (int i = 0; i < 5; i++) {
mha_input_v_tensor_shape.push_back(Add1DConstantLayer(1));
}
mha_input_v_tensor_shape[0] = GetEleTensorOfShape(input_v_shape_tensor, 0);
mha_input_v_tensor_shape[1] = GetEleTensorOfShape(input_v_shape_tensor, 1);
mha_input_v_tensor_shape[2] = Add1DConstantLayer(head_number_qk);
mha_input_v_tensor_shape[3] = Add1DConstantLayer(1);
mha_input_v_tensor_shape[4] = Add1DConstantLayer(head_size_qk);
reshape_after_fc_v_layer->setInput(1, *Concat(mha_input_v_tensor_shape));
reshape_after_fc_v_layer->setName(
("shuffle_after_fc_v_multihead_matmul(Output: " + output_name + ")")
.c_str());
std::vector<nvinfer1::ITensor*> mha_input_tensor_vector{
reshape_after_fc_qk_layer->getOutput(0),
reshape_after_fc_v_layer->getOutput(0)};
nvinfer1::ITensor* mha_input_tensor = Concat(mha_input_tensor_vector, 3);
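// reshape_after_fc_qk produces [batch, seq, head_number, 2, head_size] and
// reshape_after_fc_v produces [batch, seq, head_number, 1, head_size]; the
// concat along axis 3 therefore yields [batch, seq, head_number, 3, head_size],
// the packed QKV layout fed to the fMHA_V2 plugin below.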
auto creator = GetPluginRegistry()->getPluginCreator("fMHA_V2", "1");
assert(creator != nullptr);
std::vector<nvinfer1::PluginField> fields{};
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("fMHA_V2", plugin_collection);
free(plugin_collection);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(mha_input_tensor);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
// add shuffle
nvinfer1::ITensor* batch_tensor =
GetEleTensorOfShape(input_qk_shape_tensor, 0);
nvinfer1::ITensor* length_tensor =
GetEleTensorOfShape(input_qk_shape_tensor, 1);
auto* reshape_after_mha_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> reshape_tensor;
reshape_tensor.push_back(batch_tensor);
reshape_tensor.push_back(length_tensor);
reshape_tensor.push_back(Add1DConstantLayer(-1));
reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor));
reshape_after_mha_layer->setName(
("shuffle_last_multihead_matmul(Output: " + output_name + ")").c_str());
nvinfer1::ILayer* layer = nullptr;
layer = reshape_after_mha_layer;
RreplenishLayerAndOutput(
layer, "qk_multihead_matmul", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(qk_multihead_matmul, QkMultiheadMatMulOpConverter);
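Note: the following is a minimal, self-contained reference, not part of this patch or of the Paddle/TensorRT API, for readers who want to sanity-check the packed weights and the layout transposes above. It spells out, with naive loops over a single batch item, the computation the fused qk_multihead_matmul op stands for; all names and shape conventions here are illustrative assumptions.

// reference_qk_attention.cc (illustrative only, not built by this patch)
#include <algorithm>
#include <cmath>
#include <vector>

// Unfused reference of qk attention for one batch item:
//   q = x_qk * wq + bq,  k = x_qk * wk + bk,  v = x_v * wv + bv
//   out = softmax(scale * q * k^T) * v, computed independently per head.
// Layouts (row major): x_qk, x_v: [seq, hidden]; wq/wk/wv: [hidden, hidden];
// bq/bk/bv: [hidden]; hidden = head_number * head_size; out: [seq, hidden].
std::vector<float> QkAttentionReference(
    const std::vector<float>& x_qk, const std::vector<float>& x_v,
    const std::vector<float>& wq, const std::vector<float>& wk,
    const std::vector<float>& wv, const std::vector<float>& bq,
    const std::vector<float>& bk, const std::vector<float>& bv,
    int seq, int head_number, int head_size) {
  const int hidden = head_number * head_size;
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  // y = x * w + b, the same projection each mul + elementwise_add pair does
  auto project = [&](const std::vector<float>& x, const std::vector<float>& w,
                     const std::vector<float>& b) {
    std::vector<float> y(seq * hidden, 0.0f);
    for (int s = 0; s < seq; ++s) {
      for (int o = 0; o < hidden; ++o) {
        float acc = b[o];
        for (int i = 0; i < hidden; ++i) {
          acc += x[s * hidden + i] * w[i * hidden + o];
        }
        y[s * hidden + o] = acc;
      }
    }
    return y;
  };
  std::vector<float> q = project(x_qk, wq, bq);
  std::vector<float> k = project(x_qk, wk, bk);  // k shares x_qk with q
  std::vector<float> v = project(x_v, wv, bv);
  std::vector<float> out(seq * hidden, 0.0f);
  for (int h = 0; h < head_number; ++h) {
    for (int s = 0; s < seq; ++s) {
      // scaled dot-product scores of query position s against all key positions
      std::vector<float> prob(seq, 0.0f);
      float max_logit = -INFINITY;
      for (int t = 0; t < seq; ++t) {
        float dot = 0.0f;
        for (int d = 0; d < head_size; ++d) {
          dot += q[s * hidden + h * head_size + d] *
                 k[t * hidden + h * head_size + d];
        }
        prob[t] = scale * dot;
        max_logit = std::max(max_logit, prob[t]);
      }
      // numerically stable softmax over the key dimension
      float denom = 0.0f;
      for (int t = 0; t < seq; ++t) {
        prob[t] = std::exp(prob[t] - max_logit);
        denom += prob[t];
      }
      // weighted sum of value vectors
      for (int d = 0; d < head_size; ++d) {
        float acc = 0.0f;
        for (int t = 0; t < seq; ++t) {
          acc += (prob[t] / denom) * v[t * hidden + h * head_size + d];
        }
        out[s * hidden + h * head_size + d] = acc;
      }
    }
  }
  return out;
}

With head_number = 8 and head_size = 32 this matches the shapes exercised by the unit test below (hidden = 256).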
@@ -70,6 +70,8 @@ struct SimpleOpTypeSetTeller : public Teller {
int8_teller_set.insert("flash_multihead_matmul");
teller_set.insert("cross_multihead_matmul");
int8_teller_set.insert("cross_multihead_matmul");
teller_set.insert("qk_multihead_matmul");
int8_teller_set.insert("qk_multihead_matmul");
#endif
#if IS_TRT_VERSION_GE(8200)
teller_set.insert("round");
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertQkAttentionTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8520:
return False
return True
def sample_program_configs(self):
def generate_input1(batch, length):
return np.random.rand(batch, length, 256).astype(np.float32) / 10
def generate_input2(batch, length):
return np.random.rand(batch, length, 256).astype(np.float32) / 10
def generate_weight_q():
return np.random.rand(256, 256).astype(np.float32) / 10
def generate_weight_k():
return np.random.rand(256, 256).astype(np.float32) / 10
def generate_weight_v():
return np.random.rand(256, 256).astype(np.float32) / 10
def generate_bias_q():
return np.random.rand(256).astype(np.float32) / 10
def generate_bias_k():
return np.random.rand(256).astype(np.float32) / 10
def generate_bias_v():
return np.random.rand(256).astype(np.float32) / 10
for batch in [1, 2]:
self.batch = batch
for length in [300, 400]:
ops_config = [
# q
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["matmul_q_weight"],
},
"op_outputs": {"Out": ["matmul_q_output"]},
"op_attrs": {"trans_x": False, "trans_y": False},
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul_q_output"],
"Y": ["bias_q"],
},
"op_outputs": {"Out": ["elementwise_q_output"]},
"op_attrs": {
"Scale_out": 1.0,
"Scale_x": 1.0,
"Scale_y": 1.0,
"axis": 2,
},
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_q_output"],
},
"op_outputs": {
"Out": ["reshape_q_output"],
"XShape": ["reshape_q_output_xshape"],
},
"op_attrs": {"shape": [0, 0, 8, 32]},
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape_q_output"]},
"op_outputs": {
"Out": ["transpose_q_output"],
"XShape": ["transpose_q_output_xshape"],
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
},
},
# k
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["matmul_k_weight"],
},
"op_outputs": {"Out": ["matmul_k_output"]},
"op_attrs": {"trans_x": False, "trans_y": False},
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul_k_output"],
"Y": ["bias_k"],
},
"op_outputs": {"Out": ["elementwise_k_output"]},
"op_attrs": {
"Scale_out": 1.0,
"Scale_x": 1.0,
"Scale_y": 1.0,
"axis": 2,
},
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_k_output"],
},
"op_outputs": {
"Out": ["reshape_k_output"],
"XShape": ["reshape_k_output_xshape"],
},
"op_attrs": {"shape": [0, 0, 8, 32]},
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape_k_output"]},
"op_outputs": {
"Out": ["transpose_k_output"],
"XShape": ["transpose_k_output_xshape"],
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
},
},
# V
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data2"],
"Y": ["matmul_v_weight"],
},
"op_outputs": {"Out": ["matmul_v_output"]},
"op_attrs": {"trans_x": False, "trans_y": False},
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul_v_output"],
"Y": ["bias_v"],
},
"op_outputs": {"Out": ["elementwise_v_output"]},
"op_attrs": {
"Scale_out": 1.0,
"Scale_x": 1.0,
"Scale_y": 1.0,
"axis": 2,
},
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_v_output"],
},
"op_outputs": {
"Out": ["reshape_v_output"],
"XShape": ["reshape_v_output_xshape"],
},
"op_attrs": {"shape": [0, 0, 8, 32]},
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape_v_output"]},
"op_outputs": {
"Out": ["transpose_v_output"],
"XShape": ["transpose_v_output_xshape"],
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
},
},
# matmul1+matmul2
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["transpose_q_output"],
"Y": ["transpose_k_output"],
},
"op_outputs": {"Out": ["matmul1_output"]},
"op_attrs": {"trans_x": False, "trans_y": True},
},
{
"op_type": "scale",
"op_inputs": {
"X": ["matmul1_output"],
},
"op_outputs": {"Out": ["scale_output"]},
"op_attrs": {
"scale": 0.17677,
"bias": 0.0,
"bias_after_scale": True,
},
},
{
"op_type": "softmax",
"op_inputs": {"X": ["scale_output"]},
"op_outputs": {"Out": ["softmax_output"]},
"op_attrs": {
"axis": -1,
"data_format": "AnyLayout",
},
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["softmax_output"],
"Y": ["transpose_v_output"],
},
"op_outputs": {"Out": ["matmul2_output"]},
"op_attrs": {"trans_x": False, "trans_y": False},
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["matmul2_output"]},
"op_outputs": {
"Out": ["transpose_output"],
"XShape": ["transpose_output_xshape"],
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
},
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["transpose_output"]},
"op_outputs": {
"Out": ["reshape_output"],
"XShape": ["reshape_output_xshape"],
},
"op_attrs": {"shape": [0, 0, 256]},
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"matmul_q_weight": TensorConfig(
data_gen=partial(generate_weight_q)
),
"matmul_k_weight": TensorConfig(
data_gen=partial(generate_weight_k)
),
"matmul_v_weight": TensorConfig(
data_gen=partial(generate_weight_v)
),
"bias_q": TensorConfig(
data_gen=partial(generate_bias_q)
),
"bias_k": TensorConfig(
data_gen=partial(generate_bias_k)
),
"bias_v": TensorConfig(
data_gen=partial(generate_bias_v)
),
},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input1, batch, length)
),
"input_data2": TensorConfig(
data_gen=partial(generate_input2, batch, length)
),
},
outputs=["reshape_output"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The last dim of input1 and input2 should be static.
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 300, 256],
"input_data2": [1, 300, 256],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [4, 1200, 256],
"input_data2": [4, 1200, 256],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 300, 256],
"input_data2": [1, 300, 256],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 3), (1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 3), (1e-5, 1e-4)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), (1e-2, 1e-3)
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dynamic_shape.min_input_shape == {}:
return True
return False
self.add_skip_case(
teller1,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The qk attention trt oss plugin do not support static shape yet",
)
def teller2(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Float32:
return True
return False
self.add_skip_case(
teller2,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The qk attention trt oss plugin do not support fp32 yet",
)
def teller3(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Int8:
return True
return False
self.add_skip_case(
teller3,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The qk attention trt oss plugin do not support int8 yet.",
)
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()