未验证 提交 f706d95d 编写于 作者: F feng_shuai 提交者: GitHub

convert multihead to oss (#45019)

* convert multihead to oss

* fix:bug

* fix:delete const cast

* fix:don't support bias_qk

* add vit pass

* fix:convert bug and add preln_residual_bias

* support length=-1

* add UT for convert

* add no_bias_qk support for gpu_multihead_op

* delete infer_shape depends on bias_qk

* oss just can be used in T4 and A*

* fix:change api for ROCM CI
上级 a0bbfbd4
...@@ -117,6 +117,7 @@ pass_library(graph_viz_pass base) ...@@ -117,6 +117,7 @@ pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper) pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(vit_attention_fuse_pass inference)
pass_library(fc_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
......
...@@ -2334,6 +2334,154 @@ PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { ...@@ -2334,6 +2334,154 @@ PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
return act_out; return act_out;
} }
PDNode *patterns::VitAttention::operator()(PDNode *in) {
in->AsInput();
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto matmul0_op =
pattern->NewNode(matmul0_op_repr())->assert_is_ops(matmul_ops);
auto matmul0_in_y = pattern->NewNode(matmul0_in_y_repr())
->AsInput()
->assert_is_ops_input(matmul_ops, "Y");
auto matmul0_out = pattern->NewNode(matmul0_out_repr())
->assert_is_ops_output(matmul_ops, "Out")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto elementwise0_op =
pattern->NewNode(elementwise0_op_repr())->assert_is_op("elementwise_add");
auto elementwise0_in_y = pattern->NewNode(elementwise0_in_y_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto elementwise0_out = pattern->NewNode(elementwise0_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("reshape2", "X")
->AsIntermediate();
auto reshape1_op =
pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("transpose2", "X")
->AsIntermediate();
auto transpose1_op =
pattern->NewNode(transpose1_op_repr())->assert_is_op("transpose2");
auto transpose1_out = pattern->NewNode(transpose1_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("slice", "Input")
->AsIntermediate();
auto slice1_op = pattern->NewNode(slice1_op_repr())->assert_is_op("slice");
auto slice1_out = pattern->NewNode(slice1_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_input("matmul_v2", "Y")
->AsIntermediate();
auto slice2_op = pattern->NewNode(slice2_op_repr())->assert_is_op("slice");
auto slice2_out = pattern->NewNode(slice2_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_input("matmul_v2", "X")
->AsIntermediate();
auto slice3_op = pattern->NewNode(slice3_op_repr())->assert_is_op("slice");
auto slice3_out = pattern->NewNode(slice3_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_input("transpose2", "X")
->AsIntermediate();
auto transpose2_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
auto transpose2_out = pattern->NewNode(transpose2_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("matmul_v2", "Y")
->AsIntermediate();
auto matmul1_op =
pattern->NewNode(matmul1_op_repr())->assert_is_op("matmul_v2");
auto matmul1_out = pattern->NewNode(matmul1_out_repr())
->assert_is_op_output("matmul_v2", "Out")
->assert_is_op_input("scale", "X")
->AsIntermediate();
auto scale1_op = pattern->NewNode(scale1_op_repr())->assert_is_op("scale");
auto scale1_out = pattern->NewNode(scale1_out_repr())
->assert_is_op_output("scale", "Out")
->assert_is_op_input("softmax", "X")
->AsIntermediate();
auto softmax1_op =
pattern->NewNode(softmax1_op_repr())->assert_is_op("softmax");
auto softmax1_out = pattern->NewNode(softmax1_out_repr())
->assert_is_op_output("softmax", "Out")
->assert_is_op_input("matmul_v2", "X")
->AsIntermediate();
auto matmul2_op =
pattern->NewNode(matmul2_op_repr())->assert_is_op("matmul_v2");
auto matmul2_out = pattern->NewNode(matmul2_out_repr())
->assert_is_op_output("matmul_v2", "Out")
->assert_is_op_input("transpose2", "X")
->AsIntermediate();
auto transpose3_op =
pattern->NewNode(transpose3_op_repr())->assert_is_op("transpose2");
auto transpose3_out = pattern->NewNode(transpose3_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("reshape2", "X")
->AsIntermediate();
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
->assert_is_op_output("reshape2", "Out")
->AsOutput();
matmul0_op->LinksFrom({in, matmul0_in_y});
matmul0_out->LinksFrom({matmul0_op});
elementwise0_op->LinksFrom({matmul0_out, elementwise0_in_y});
elementwise0_out->LinksFrom({elementwise0_op});
reshape1_op->LinksFrom({elementwise0_out});
reshape1_out->LinksFrom({reshape1_op});
transpose1_op->LinksFrom({reshape1_out});
transpose1_out->LinksFrom({transpose1_op});
slice1_op->LinksFrom({transpose1_out});
slice1_out->LinksFrom({slice1_op});
slice2_op->LinksFrom({transpose1_out});
slice2_out->LinksFrom({slice2_op});
slice3_op->LinksFrom({transpose1_out});
slice3_out->LinksFrom({slice3_op});
transpose2_op->LinksFrom({slice3_out});
transpose2_out->LinksFrom({transpose2_op});
matmul1_op->LinksFrom({slice2_out, transpose2_out});
matmul1_out->LinksFrom({matmul1_op});
scale1_op->LinksFrom({matmul1_out});
scale1_out->LinksFrom({scale1_op});
softmax1_op->LinksFrom({scale1_out});
softmax1_out->LinksFrom({softmax1_op});
matmul2_op->LinksFrom({slice1_out, softmax1_out});
matmul2_out->LinksFrom({matmul2_op});
transpose3_op->LinksFrom({matmul2_out});
transpose3_out->LinksFrom({transpose3_op});
reshape2_op->LinksFrom({transpose3_out});
reshape2_out->LinksFrom({reshape2_op});
return reshape2_out;
}
PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) { PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_filter = pattern->NewNode(conv_filter_repr()) auto conv_filter = pattern->NewNode(conv_filter_repr())
......
...@@ -1377,6 +1377,58 @@ struct PriorBox : public PatternBase { ...@@ -1377,6 +1377,58 @@ struct PriorBox : public PatternBase {
PATTERN_DECL_NODE(prior_box_variances); PATTERN_DECL_NODE(prior_box_variances);
}; };
// vit_attention
struct VitAttention : public PatternBase {
VitAttention(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "vit_attention") {}
PDNode* operator()(PDNode* in);
PATTERN_DECL_NODE(matmul0_op);
PATTERN_DECL_NODE(matmul0_in_y);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(elementwise0_op);
PATTERN_DECL_NODE(elementwise0_in_y);
PATTERN_DECL_NODE(elementwise0_out);
PATTERN_DECL_NODE(reshape1_op);
PATTERN_DECL_NODE(reshape1_out);
PATTERN_DECL_NODE(transpose1_op);
PATTERN_DECL_NODE(transpose1_out);
PATTERN_DECL_NODE(slice1_op);
PATTERN_DECL_NODE(slice1_out);
PATTERN_DECL_NODE(slice2_op);
PATTERN_DECL_NODE(slice2_out);
PATTERN_DECL_NODE(slice3_op);
PATTERN_DECL_NODE(slice3_out);
PATTERN_DECL_NODE(matmul2_op);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(matmul1_op);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(transpose2_op);
PATTERN_DECL_NODE(transpose2_out);
PATTERN_DECL_NODE(scale1_op);
PATTERN_DECL_NODE(scale1_out);
PATTERN_DECL_NODE(softmax1_op);
PATTERN_DECL_NODE(softmax1_out);
PATTERN_DECL_NODE(transpose3_op);
PATTERN_DECL_NODE(transpose3_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
// Conv + ElementwiseAdd + an activation // Conv + ElementwiseAdd + an activation
// This pattern can further fuse the conv related ops after the conv+bn fusion. // This pattern can further fuse the conv related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase { struct ConvElementwiseaddAct : public PatternBase {
......
...@@ -45,7 +45,7 @@ class PrelnResidualBiasFusePass : public FusePassBase { ...@@ -45,7 +45,7 @@ class PrelnResidualBiasFusePass : public FusePassBase {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("axis") .AddAttr("axis")
.IsIntIn({0, -1}) .IsIntIn({0, -1, 2})
.End(); .End();
AddOpCompat(OpCompat("layer_norm")) AddOpCompat(OpCompat("layer_norm"))
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/vit_attention_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(matmul0_op); \
GET_IR_NODE(matmul0_in_y); \
GET_IR_NODE(matmul0_out); \
GET_IR_NODE(elementwise0_op); \
GET_IR_NODE(elementwise0_in_y); \
GET_IR_NODE(elementwise0_out); \
GET_IR_NODE(reshape1_op); \
GET_IR_NODE(reshape1_out); \
GET_IR_NODE(transpose1_op); \
GET_IR_NODE(transpose1_out); \
GET_IR_NODE(slice1_op); \
GET_IR_NODE(slice1_out); \
GET_IR_NODE(slice2_op); \
GET_IR_NODE(slice2_out); \
GET_IR_NODE(slice3_op); \
GET_IR_NODE(slice3_out); \
GET_IR_NODE(matmul1_op); \
GET_IR_NODE(matmul1_out); \
GET_IR_NODE(scale1_op); \
GET_IR_NODE(scale1_out); \
GET_IR_NODE(transpose2_op); \
GET_IR_NODE(transpose2_out); \
GET_IR_NODE(softmax1_op); \
GET_IR_NODE(softmax1_out); \
GET_IR_NODE(matmul2_op); \
GET_IR_NODE(matmul2_out); \
GET_IR_NODE(transpose3_op); \
GET_IR_NODE(transpose3_out); \
GET_IR_NODE(reshape2_op); \
GET_IR_NODE(reshape2_out);
namespace paddle {
namespace framework {
namespace ir {
void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const {
GraphPatternDetector gpd;
const std::string pattern_name = "vit_attention_fuse";
FusePassBase::Init(pattern_name, graph);
auto* scope = param_scope();
// pattern
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
PDNode* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_ops_input(matmul_ops, "X")
->AsInput();
patterns::VitAttention pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
int fusion_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
// do something;
OpDesc desc(matmul0_op->Op()->Block());
desc.SetType("multihead_matmul");
desc.SetInput("Input", {subgraph.at(x)->Name()});
// refactor W and Bias
auto* w_tensor =
scope->FindVar(matmul0_in_y->Name())->GetMutable<LoDTensor>();
auto w_dims =
phi::make_ddim({w_tensor->dims()[0], 3, w_tensor->dims()[1] / 3});
w_tensor->Resize(w_dims);
auto* b_tensor =
scope->FindVar(elementwise0_in_y->Name())->GetMutable<LoDTensor>();
auto bias_dims = phi::make_ddim({3, b_tensor->dims()[0] / 3});
b_tensor->Resize(bias_dims);
desc.SetInput("W", {matmul0_in_y->Name()});
desc.SetInput("Bias", {elementwise0_in_y->Name()});
std::vector<int64_t> shape = softmax1_out->Var()->GetShape();
desc.SetOutput("Out", {reshape2_out->Name()});
desc.SetAttr("head_number", static_cast<int>(shape[1]));
float alpha = PADDLE_GET_CONST(float, scale1_op->Op()->GetAttr("scale"));
desc.SetAttr("alpha", alpha);
// Create a new node for the fused op.
auto vit_attention_node = graph->CreateOpNode(&desc);
// Link inputs and outputs.
PADDLE_ENFORCE_NE(
subgraph.count(x),
0,
platform::errors::NotFound("Detector did not find input x of conv2d."));
IR_NODE_LINK_TO(subgraph.at(x), vit_attention_node); // Input
IR_NODE_LINK_TO(matmul0_in_y, vit_attention_node);
IR_NODE_LINK_TO(elementwise0_in_y, vit_attention_node);
IR_NODE_LINK_TO(vit_attention_node, reshape2_out); // Output
// Delete the unneeded nodes.
std::unordered_set<const Node*> marked_nodes(
{matmul0_op, matmul0_out, elementwise0_op, elementwise0_out,
reshape1_op, reshape1_out, transpose1_op, transpose1_out,
slice1_op, slice1_out, slice2_op, slice2_out,
slice3_op, slice3_out, matmul1_op, matmul1_out,
scale1_op, scale1_out, transpose2_op, transpose2_out,
softmax1_op, softmax1_out, matmul2_op, matmul2_out,
transpose3_op, transpose3_out, reshape2_op});
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(vit_attention_fuse_pass,
paddle::framework::ir::VitAttentionFusePass);
REGISTER_PASS_CAPABILITY(vit_attention_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("reshape2", 0)
.GE("transpose2", 0)
.GE("slice", 0)
.GE("scale", 0)
.GE("softmax", 0)
.GE("matmul_v2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
// Fusing of vit_attention structure
class Graph;
class VitAttentionFusePass : public FusePassBase {
public:
virtual ~VitAttentionFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -101,6 +101,7 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -101,6 +101,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"delete_c_identity_op_pass", // "delete_c_identity_op_pass", //
"trt_multihead_matmul_fuse_pass_v2", // "trt_multihead_matmul_fuse_pass_v2", //
"trt_multihead_matmul_fuse_pass_v3", // "trt_multihead_matmul_fuse_pass_v3", //
"vit_attention_fuse_pass", //
"trt_skip_layernorm_fuse_pass", // "trt_skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", // "preln_skip_layernorm_fuse_pass", //
"preln_residual_bias_fuse_pass", // "preln_residual_bias_fuse_pass", //
......
...@@ -29,6 +29,10 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -29,6 +29,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("Input").front()); auto* input = engine_->GetITensor(op_desc.Input("Input").front());
auto input_dims = input->getDimensions();
bool bias_qk_attr =
(op_desc.Inputs().find("BiasQK") == op_desc.Inputs().end()) ? false
: true;
// fc weights and fc bias // fc weights and fc bias
auto weight_name = op_desc.Input("W").front(); auto weight_name = op_desc.Input("W").front();
...@@ -287,6 +291,234 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -287,6 +291,234 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.data(), plugin_inputs.size(), *plugin); plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer; layer = plugin_layer;
} }
}
if (input_dims.d[1] <= 384 && !bias_qk_attr &&
engine_->precision() != AnalysisConfig::Precision::kFloat32) {
/*
* input_dims.d[0]: batch(-1)
* input_dims.d[1]: length:256
* input_dims.d[2]: hidden_size:768
input
|[b,256,768]
|
shuffle weight bias
|[b,256,768,1,1] | |
|_____________________|_________|
|
fc
|[b,256,2304,1,1]
|
shuffle mask(fake) pos max_length
|[b*256,2304,1,1] | | |
| | | |
|_______________________|_________|________|
|
MHA
|[b*256,768]
|
shuffle
|[b, 256, 768]
|
out
*/
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())};
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())};
/*** transpose the weight and bias ***/
int head_size = hidden_out / head_number;
// [3, head_number, head_size, hidden_in] -> [head_number, 3,
// head_size, hidden_in]
auto transpose_weight_v2 = [](const float* src,
float* dst,
int three,
int head_number,
int head_size,
int hidden_in) {
const int HH = head_size * hidden_in;
for (int i = 0; i < three; ++i) {
for (int n = 0; n < head_number; ++n) {
for (int hh = 0; hh < HH; ++hh) {
dst[n * three * HH + i * HH + hh] =
src[i * head_number * HH + n * HH + hh];
}
}
}
};
// [3, head_number, head_size] -> [head_number, 3, head_size]
auto transpose_bias_v2 =
[](const float* src, float* dst, int N, int H) {
for (int i = 0; i < 3; ++i) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h];
}
}
}
};
memcpy(weight_data_tmp.data(),
weight_data,
weight_t->numel() * sizeof(float));
transpose_weight_v2(weight_data_tmp.data(),
weight_data,
three,
head_number,
head_size,
hidden_in);
std::vector<float> bias_data_tmp;
bias_data_tmp.reserve(bias_t->numel());
memcpy(
bias_data_tmp.data(), bias_data, bias_t->numel() * sizeof(float));
transpose_bias_v2(
bias_data_tmp.data(), bias_data, head_number, head_size);
// add shuffle for FullyConnected layer
std::vector<nvinfer1::ITensor*> reshape_before_fc_shape_tensor;
nvinfer1::ITensor* input_shape_tensor = Shape(input);
for (int i = 0; i < 5; i++) {
reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1));
}
for (int i = 0; i < 3; i++) {
reshape_before_fc_shape_tensor[i] =
GetEleTensorOfShape(input_shape_tensor, i);
}
auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
reshape_before_fc_layer->setInput(
1, *Concat(reshape_before_fc_shape_tensor));
reshape_before_fc_layer->setName(
("shuffle_before_fc_multihead_matmul(Output: " + output_name + ")")
.c_str());
// add fc layer
nvinfer1::ILayer* fc_layer = nullptr;
fc_layer = TRT_ENGINE_ADD_LAYER(engine_,
FullyConnected,
*reshape_before_fc_layer->getOutput(0),
n,
weight,
bias);
// add shuffle for CustomQKVToContextPluginDynamic layer
auto* reshape_after_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> mha_input_tensor_shape;
mha_input_tensor_shape.push_back(Add1DConstantLayer(-1));
mha_input_tensor_shape.push_back(
Add1DConstantLayer(hidden_out * 3)); // Q,K,V
mha_input_tensor_shape.push_back(Add1DConstantLayer(1));
mha_input_tensor_shape.push_back(Add1DConstantLayer(1));
reshape_after_fc_layer->setInput(1, *Concat(mha_input_tensor_shape));
reshape_after_fc_layer->setName(
("shuffle_after_fc_multihead_matmul(Output: " + output_name + ")")
.c_str());
// add mha_plugin
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "2");
assert(creator != nullptr);
// set the attributes of mha_plugin
int type = static_cast<int>(nvinfer1::DataType::kHALF);
int var_seqlen = 1;
bool has_mask = true;
std::vector<nvinfer1::PluginField> fields{
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}};
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic",
plugin_collection);
free(plugin_collection);
// set inputs
std::vector<nvinfer1::ITensor*> plugin_inputs;
// input_0 for plugin
plugin_inputs.emplace_back(reshape_after_fc_layer->getOutput(0));
// input_1(fake) for plugin
std::vector<int> mask = {1};
nvinfer1::ITensor* mask_tensor = Add1DConstantLayer(mask);
plugin_inputs.emplace_back(mask_tensor);
// input_2 for plugin
std::vector<int> pos_id = {0};
int max_batch = 500;
for (int i = 1; i < max_batch; i++) {
pos_id.push_back(i);
}
nvinfer1::ITensor* fake_pos_id_tensor = Add1DConstantLayer(pos_id);
nvinfer1::ITensor* length_tensor =
GetEleTensorOfShape(input_shape_tensor, 1);
auto pos_id_layer =
TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*fake_pos_id_tensor,
*length_tensor,
nvinfer1::ElementWiseOperation::kPROD);
// size = batch + 1;
nvinfer1::ITensor* batch_tensor =
GetEleTensorOfShape(input_shape_tensor, 0);
std::vector<int> const_data = {1};
nvinfer1::ITensor* const_tensor = Add1DConstantLayer(const_data);
auto size_layer =
TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*batch_tensor,
*const_tensor,
nvinfer1::ElementWiseOperation::kSUM);
// get size(batch + 1) data from pos_id_tensor
nvinfer1::Dims start;
nvinfer1::Dims stride;
nvinfer1::Dims size;
start.nbDims = 1;
stride.nbDims = 1;
size.nbDims = 1;
start.d[0] = 0;
stride.d[0] = 1;
size.d[0] = 1;
auto* slice_pos_layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *pos_id_layer->getOutput(0), start, size, stride);
slice_pos_layer->setInput(2, *size_layer->getOutput(0));
plugin_inputs.emplace_back(slice_pos_layer->getOutput(0));
// input_3 for plugin
std::vector<int> data(500, 1);
nvinfer1::ITensor* fake_max_seqlen_tensor = Add1DConstantLayer(data);
auto* slice_max_layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *fake_max_seqlen_tensor, start, size, stride);
slice_max_layer->setInput(2, *length_tensor);
plugin_inputs.emplace_back(slice_max_layer->getOutput(0));
// plugin_layer
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
// add shuffle
auto* reshape_after_mha_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> reshape_tensor;
reshape_tensor.push_back(batch_tensor);
reshape_tensor.push_back(length_tensor);
reshape_tensor.push_back(Add1DConstantLayer(-1));
reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor));
reshape_after_mha_layer->setName(
("shuffle_last_multihead_matmul(Output: " + output_name + ")")
.c_str());
// return
layer = reshape_after_mha_layer;
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input->getDimensions().nbDims, input->getDimensions().nbDims,
......
...@@ -1820,7 +1820,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, ...@@ -1820,7 +1820,9 @@ bool OpTeller::Tell(const framework::ir::Node* node,
const auto input_shape = input_desc->GetShape(); const auto input_shape = input_desc->GetShape();
const auto head_number = const auto head_number =
PADDLE_GET_CONST(int, desc.GetAttr("head_number")); PADDLE_GET_CONST(int, desc.GetAttr("head_number"));
auto inputs = desc.Inputs();
bool has_bias_qk = (inputs.find("BiasQK") == inputs.end()) ? false : true;
if (has_bias_qk) {
auto* biasqk_desc = block->FindVar(desc.Input("BiasQK").front()); auto* biasqk_desc = block->FindVar(desc.Input("BiasQK").front());
const auto biasqk_shape = biasqk_desc->GetShape(); const auto biasqk_shape = biasqk_desc->GetShape();
// The BiasQK's shape requires to be // The BiasQK's shape requires to be
...@@ -1839,6 +1841,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, ...@@ -1839,6 +1841,12 @@ bool OpTeller::Tell(const framework::ir::Node* node,
<< biasqk_shape[3] << "]."; << biasqk_shape[3] << "].";
return false; return false;
} }
} else {
#if !IS_TRT_VERSION_GE(8000)
VLOG(3) << "The version of TRT must be greater than 8000";
return false;
#endif
}
} }
if (op_type == "fc") { if (op_type == "fc") {
......
...@@ -40,11 +40,6 @@ class MultiHeadMatMulV2Op : public framework::OperatorWithKernel { ...@@ -40,11 +40,6 @@ class MultiHeadMatMulV2Op : public framework::OperatorWithKernel {
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input(Bias) of MultiHeadMatMul should not be null.")); "Input(Bias) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ(
context->HasInput("BiasQK"),
true,
platform::errors::InvalidArgument(
"Input(BiasQK) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
context->HasOutput("Out"), context->HasOutput("Out"),
true, true,
...@@ -69,15 +64,6 @@ class MultiHeadMatMulV2Op : public framework::OperatorWithKernel { ...@@ -69,15 +64,6 @@ class MultiHeadMatMulV2Op : public framework::OperatorWithKernel {
"%d-D tensor now.", "%d-D tensor now.",
dim_bias_q.size())); dim_bias_q.size()));
auto dim_bias_qk = context->GetInputDim("BiasQK");
PADDLE_ENFORCE_GT(
dim_bias_qk.size(),
3,
platform::errors::InvalidArgument(
"Multihead input bias qk should be at least 4-D tensor, "
"but it's %d-D tensor now.",
dim_bias_qk.size()));
auto dim_input = context->GetInputDim("Input"); auto dim_input = context->GetInputDim("Input");
context->SetOutputDim("Out", dim_input); context->SetOutputDim("Out", dim_input);
context->ShareLoD("Input", /*->*/ "Out"); context->ShareLoD("Input", /*->*/ "Out");
...@@ -90,7 +76,8 @@ class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -90,7 +76,8 @@ class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Input", "The input of MultiHeadMatMul op"); AddInput("Input", "The input of MultiHeadMatMul op");
AddInput("W", "The weight input of MultiHeadMatMul op"); AddInput("W", "The weight input of MultiHeadMatMul op");
AddInput("Bias", "The bias input of MultiHeadMatMul op"); AddInput("Bias", "The bias input of MultiHeadMatMul op");
AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op"); AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op")
.AsDispensable();
AddOutput("Out", "The output of MultiHeadMatMul op"); AddOutput("Out", "The output of MultiHeadMatMul op");
AddAttr<bool>("transpose_Q", AddAttr<bool>("transpose_Q",
R"DOC(If true, use the transpose of `Q`. R"DOC(If true, use the transpose of `Q`.
......
...@@ -264,15 +264,12 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -264,15 +264,12 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto *input = context.Input<framework::Tensor>("Input"); auto *input = context.Input<framework::Tensor>("Input");
auto *w = context.Input<framework::Tensor>("W"); auto *w = context.Input<framework::Tensor>("W");
auto *bias = context.Input<framework::Tensor>("Bias"); auto *bias = context.Input<framework::Tensor>("Bias");
auto &bias_qk = GET_DATA_SAFELY(context.Input<framework::Tensor>("BiasQK"), auto *bias_qk = context.Input<framework::Tensor>("BiasQK");
"Input",
"BiasQK",
"MultiHeadMatMulV2");
auto *input_d = input->data<T>(); auto *input_d = input->data<T>();
auto *w_d = w->data<T>(); auto *w_d = w->data<T>();
auto *bias_d = bias->data<T>(); auto *bias_d = bias->data<T>();
auto *bias_qk_d = bias_qk.template data<T>(); auto *bias_qk_d = bias_qk ? bias_qk->data<T>() : nullptr;
T scale = static_cast<T>(context.Attr<float>("alpha")); T scale = static_cast<T>(context.Attr<float>("alpha"));
int head_number = context.Attr<int>("head_number"); int head_number = context.Attr<int>("head_number");
...@@ -288,7 +285,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -288,7 +285,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int hidden = input_dims[2]; int hidden = input_dims[2];
Tensor temp_bias_tensor; Tensor temp_bias_tensor;
// if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted // if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
if (bias_qk.numel() == (batch * seq_len)) { if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len}); temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace()); auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace());
int grid = batch * head_number * seq_len; int grid = batch * head_number * seq_len;
...@@ -297,6 +294,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -297,6 +294,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
bias_qk_d, temp_qk_bias, seq_len, head_number); bias_qk_d, temp_qk_bias, seq_len, head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias); bias_qk_d = static_cast<const T *>(temp_qk_bias);
} }
if (!bias_qk) {
int size = batch * head_number * seq_len * seq_len;
temp_bias_tensor.Resize({size});
auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace());
#ifdef PADDLE_WITH_HIP
hipMemset(temp_qk_bias, 0, sizeof(float) * size);
#else
cudaMemset(temp_qk_bias, 0, sizeof(float) * size);
#endif
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
int all_head_size = w_dims[2]; int all_head_size = w_dims[2];
int head_size = all_head_size / head_number; int head_size = all_head_size / head_number;
......
...@@ -848,5 +848,271 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): ...@@ -848,5 +848,271 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
yield program_config yield program_config
class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input1():
return np.zeros((1, 256, 768), dtype=np.float32)
def generate_weight1():
return np.random.rand(768, 2304).astype(np.float32)
def generate_weight2():
return np.random.rand(2304).astype(np.float32)
ops_config = [{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["matmul1_weight"]
},
"op_outputs": {
"Out": ["matmul1_output"]
},
"op_attrs": {
"trans_x": False,
"trans_y": False
}
}, {
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul1_output"],
"Y": ["elementwise_add1_weight"]
},
"op_outputs": {
"Out": ["elementwise_add1_output"]
},
"op_attrs": {
"Scale_out": 1.0,
"Scale_x": 1.0,
"Scale_y": 1.0,
"axis": 2
}
}, {
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_add1_output"],
},
"op_outputs": {
"Out": ["reshape1_output"],
"XShape": ["reshape1_output_xshape"]
},
"op_attrs": {
"shape": [-1, 256, 3, 12, 64]
}
}, {
"op_type": "transpose2",
"op_inputs": {
"X": ["reshape1_output"]
},
"op_outputs": {
"Out": ["transpose1_output"],
"XShape": ["transpose1_output_xshape"]
},
"op_attrs": {
"axis": [2, 0, 3, 1, 4],
"data_format": "AnyLayout"
}
}, {
"op_type": "slice",
"op_inputs": {
"Input": ["transpose1_output"],
},
"op_outputs": {
"Out": ["slice1_output"]
},
"op_attrs": {
"axes": [0],
"starts": [0],
"ends": [1],
"decrease_axis": [0],
"infer_flags": [1]
}
}, {
"op_type": "slice",
"op_inputs": {
"Input": ["transpose1_output"],
},
"op_outputs": {
"Out": ["slice2_output"]
},
"op_attrs": {
"axes": [0],
"starts": [1],
"ends": [2],
"decrease_axis": [0],
"infer_flags": [1]
}
}, {
"op_type": "slice",
"op_inputs": {
"Input": ["transpose1_output"],
},
"op_outputs": {
"Out": ["slice3_output"]
},
"op_attrs": {
"axes": [0],
"starts": [2],
"ends": [3],
"decrease_axis": [0],
"infer_flags": [1]
}
}, {
"op_type": "transpose2",
"op_inputs": {
"X": ["slice2_output"]
},
"op_outputs": {
"Out": ["transpose2_output"],
},
"op_attrs": {
"axis": [0, 1, 3, 2],
"data_format": "AnyLayout"
}
}, {
"op_type": "matmul_v2",
"op_inputs": {
"X": ["slice1_output"],
"Y": ["transpose2_output"]
},
"op_outputs": {
"Out": ["matmul2_output"]
},
"op_attrs": {
"trans_x": False,
"trans_y": False
}
}, {
"op_type": "scale",
"op_inputs": {
"X": ["matmul2_output"],
},
"op_outputs": {
"Out": ["scale_output"]
},
"op_attrs": {
"scale": 0.125,
"bias": 0.0,
"bias_after_scale": True
}
}, {
"op_type": "softmax",
"op_inputs": {
"X": ["scale_output"]
},
"op_outputs": {
"Out": ["softmax_output"]
},
"op_attrs": {
"axis": -1,
"data_format": "AnyLayout"
}
}, {
"op_type": "matmul_v2",
"op_inputs": {
"X": ["softmax_output"],
"Y": ["slice3_output"]
},
"op_outputs": {
"Out": ["matmul3_output"]
},
"op_attrs": {
"trans_x": False,
"trans_y": False
}
}, {
"op_type": "transpose2",
"op_inputs": {
"X": ["matmul3_output"]
},
"op_outputs": {
"Out": ["transpose3_output"],
"XShape": ["transpose3_output_xshape"]
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout"
}
}, {
"op_type": "reshape2",
"op_inputs": {
"X": ["transpose3_output"]
},
"op_outputs": {
"Out": ["reshape2_output"],
"XShape": ["reshape2_output_xshape"]
},
"op_attrs": {
"shape": [-1, 256, 768]
}
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"matmul1_weight":
TensorConfig(data_gen=partial(generate_weight1)),
"elementwise_add1_weight":
TensorConfig(data_gen=partial(generate_weight2))
},
inputs={
"input_data1": TensorConfig(data_gen=partial(generate_input1))
},
outputs=["reshape2_output"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The last dim of input1 and input2 should be static.
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 8, 768],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [16, 512, 768],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 197, 768],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
def generate_trt_nodes_num():
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8000:
return 0, 3
return 1, 2
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.workspace_size = 2013265920
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(), (1e-3,
1e-3)
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册