未验证 提交 2810dfea 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference] new general transformer inference support (#43077)

* new general transformer inference support
上级 0cb9dae5
......@@ -107,6 +107,9 @@ target_link_libraries(generate_pass pass_desc_proto)
if(WITH_TENSORRT)
pass_library(trt_map_matmul_to_mul_pass inference)
pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(trt_multihead_matmul_fuse_pass inference)
pass_library(trt_skip_layernorm_fuse_pass inference)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(preln_skip_layernorm_fuse_pass inference)
pass_library(set_transformer_input_convert_pass inference)
......
......@@ -430,13 +430,15 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
bool enable_int8 = Get<bool>("enable_int8");
bool use_oss = Get<bool>("use_oss");
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) {
if (!(enable_int8 && use_varseqlen && with_interleaved &&
with_dynamic_shape)) {
VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
"enable_int8, "
"use_oss, with_interleaved, with_dynamic_shape. Stop this pass, "
"use_varseqlen, with_interleaved, with_dynamic_shape. Stop this "
"pass, "
"please reconfig.";
return;
}
......
......@@ -109,12 +109,13 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_skip_layernorm_fuse", graph);
bool enable_int8 = Get<bool>("enable_int8");
bool use_oss = Get<bool>("use_oss");
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) {
if (!(enable_int8 && use_varseqlen && with_interleaved &&
with_dynamic_shape)) {
VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
"use_oss, "
"use_varseqlen, "
"with_interleaved, with_dynamic_shape. Stop this pass, please "
"reconfig. ";
return;
......
......@@ -22,6 +22,19 @@ namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Builds the pattern for a fused_embedding_eltwise_layernorm op together with
// the variable bound to its "Out" output.
void EmbEltwiseLayernorm::operator()() {
  // Operator node: must be a fused_embedding_eltwise_layernorm op.
  auto* fused_emb_op = pattern->NewNode(emb_elt_layernorm_op_repr())
                           ->assert_is_op("fused_embedding_eltwise_layernorm");

  // Variable node: must be the "Out" output of that op type.
  auto* fused_emb_out =
      pattern->NewNode(emb_elt_layernorm_out_repr())
          ->assert_is_op_output("fused_embedding_eltwise_layernorm", "Out");

  // Single edge of the pattern: op -> Out.
  fused_emb_op->LinksTo({fused_emb_out});
}
void SkipLayernorm::operator()() {
// Create nodes for skip_layernorm.
auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())
......@@ -59,16 +72,12 @@ void Fc::operator()() {
auto* fc_input =
pattern->NewNode(fc_input_repr())->assert_is_op_input("fc", "Input");
auto* fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
auto* fc_out =
pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
// Add links for fc op.
fc_op->LinksFrom({fc_input}).LinksTo({fc_out});
fc_op->LinksFrom({fc_input});
}
void Activation::operator()() {
// Create nodes for activation.
std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "tanh"};
std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "gelu"};
auto* activation_input = pattern->NewNode(activation_input_repr())
->assert_is_ops_input(activation_ops);
auto* activation_op =
......@@ -82,6 +91,18 @@ void Activation::operator()() {
} // namespace patterns
void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
bool use_varseqlen = Get<bool>("use_varseqlen");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "" &&
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)) {
VLOG(3) << "start varseqlen remove_padding_recover_padding_pass";
} else {
return;
}
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init(name_scope_, graph);
......@@ -91,14 +112,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
// Create a remove_padding op node
auto insert_remove_padding_op = [&](Node* input_node, Node* op_node) {
// create op, var in graph
OpDesc remove_padding;
OpDesc remove_padding(op_node->Op()->Block());
std::string remove_padding_out_name =
input_node->Name() + ".remove_padding";
VarDesc remove_padding_out(remove_padding_out_name);
remove_padding_out.SetDataType(input_node->Var()->GetDataType());
remove_padding_out.SetShape(input_node->Var()->GetShape());
remove_padding_out.SetPersistable(false);
auto* remove_padding_out =
op_node->Op()->Block()->Var(remove_padding_out_name);
remove_padding_out->SetDataType(input_node->Var()->GetDataType());
remove_padding_out->SetShape(input_node->Var()->GetShape());
remove_padding_out->SetPersistable(false);
// remove_padding_op
remove_padding.SetType("remove_padding");
......@@ -110,7 +131,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
remove_padding.SetOutput("Out", {remove_padding_out_name});
auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);
auto remove_padding_out_node = graph->CreateVarNode(&remove_padding_out);
auto remove_padding_out_node = graph->CreateVarNode(remove_padding_out);
// replace link
for (size_t i = 0; i < input_node->outputs.size(); ++i) {
......@@ -145,13 +166,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
// create a recover_padding op node
auto insert_recover_padding_op = [&](Node* op_node, Node* out_node) {
// create op, var in graph
OpDesc recover_padding;
OpDesc recover_padding(op_node->Op()->Block());
std::string recover_padding_input_name =
out_node->Name() + ".recover_padding";
VarDesc recover_padding_input(recover_padding_input_name);
recover_padding_input.SetDataType(out_node->Var()->GetDataType());
recover_padding_input.SetShape(out_node->Var()->GetShape());
recover_padding_input.SetPersistable(false);
auto* recover_padding_input =
op_node->Op()->Block()->Var(recover_padding_input_name);
recover_padding_input->SetDataType(out_node->Var()->GetDataType());
recover_padding_input->SetShape(out_node->Var()->GetShape());
recover_padding_input->SetPersistable(false);
// recover_padding_op
recover_padding.SetType("recover_padding");
......@@ -164,7 +186,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);
auto recover_padding_input_node =
graph->CreateVarNode(&recover_padding_input);
graph->CreateVarNode(recover_padding_input);
// replace link
for (size_t i = 0; i < op_node->outputs.size(); ++i) {
......@@ -195,39 +217,36 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
op_node->Op()->RenameOutput(out_node->Name(), recover_padding_input_name);
};
GraphPatternDetector gpd1;
patterns::SkipLayernorm skip_layernorm(gpd1.mutable_pattern(),
"remove_padding_recover_padding_pass");
skip_layernorm();
bool check_flag = true;
auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph,
GraphPatternDetector gpd0;
patterns::EmbEltwiseLayernorm fused_embedding_eltwise_layernorm(
gpd0.mutable_pattern(), "remove_padding_recover_padding_pass");
fused_embedding_eltwise_layernorm();
auto handler0 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"skip_layernorm";
"fused_embedding_eltwise_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, emb_elt_layernorm_op,
fused_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, emb_elt_layernorm_out,
fused_embedding_eltwise_layernorm);
insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op);
insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op);
insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out);
insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out);
found_subgraph_count++;
};
gpd1(graph, handler1);
gpd0(graph, handler0);
GraphPatternDetector gpd2;
GraphPatternDetector gpd1;
patterns::MultiheadMatmul multihead_matmul(
gpd2.mutable_pattern(), "remove_padding_recover_padding_pass");
gpd1.mutable_pattern(), "remove_padding_recover_padding_pass");
multihead_matmul();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
std::vector<int64_t> multihead_matmul_input_shape;
auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"multihead_matmul";
......@@ -239,11 +258,57 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out,
multihead_matmul);
multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape();
insert_remove_padding_op(multihead_matmul_input, multihead_matmul_op);
insert_recover_padding_op(multihead_matmul_op, multihead_matmul_out);
found_subgraph_count++;
};
gpd1(graph, handler1);
GraphPatternDetector gpd2;
patterns::SkipLayernorm skip_layernorm(gpd2.mutable_pattern(),
"remove_padding_recover_padding_pass");
skip_layernorm();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"skip_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
skip_layernorm);
std::vector<int64_t> skip_layernorm_x_shape =
skip_layernorm_x->Var()->GetShape();
if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
check_flag = false;
}
}
if (!check_flag) {
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op);
insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op);
insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out);
found_subgraph_count++;
};
gpd2(graph, handler2);
GraphPatternDetector gpd3;
......@@ -257,11 +322,39 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, fc);
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc);
insert_remove_padding_op(fc_input, fc_op);
insert_recover_padding_op(fc_op, fc_out);
std::vector<int64_t> fc_input_shape = fc_input->Var()->GetShape();
if ((fc_input_shape.size() != multihead_matmul_input_shape.size()) ||
(fc_input_shape.size() != 3)) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
if (fc_input_shape[0] != multihead_matmul_input_shape[0]) {
check_flag = false;
}
if (fc_input_shape[1] != multihead_matmul_input_shape[1]) {
check_flag = false;
}
if ((fc_input_shape[2] != multihead_matmul_input_shape[2]) &&
(fc_input_shape[2] != 4 * multihead_matmul_input_shape[2])) {
check_flag = false;
}
if (BOOST_GET_CONST(int, fc_op->Op()->GetAttr("in_num_col_dims")) != 2) {
check_flag = false;
}
if (!check_flag) {
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
fc_op->Op()->RemoveAttr("in_num_col_dims");
fc_op->Op()->SetAttr("in_num_col_dims", 1);
insert_remove_padding_op(fc_input, fc_op);
insert_recover_padding_op(fc_op, fc_op->outputs[0]);
found_subgraph_count++;
};
gpd3(graph, handler3);
......@@ -280,6 +373,31 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation_op, activation);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, activation);
std::vector<int64_t> activation_input_shape =
activation_input->Var()->GetShape();
if ((activation_input_shape.size() !=
multihead_matmul_input_shape.size()) ||
(activation_input_shape.size() != 3)) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
if (activation_input_shape[0] != multihead_matmul_input_shape[0]) {
check_flag = false;
}
if (activation_input_shape[1] != multihead_matmul_input_shape[1]) {
check_flag = false;
}
if ((activation_input_shape[2] != multihead_matmul_input_shape[2]) &&
(activation_input_shape[2] != 4 * multihead_matmul_input_shape[2])) {
check_flag = false;
}
if (!check_flag) {
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
insert_remove_padding_op(activation_input, activation_op);
insert_recover_padding_op(activation_op, activation_out);
......
......@@ -32,6 +32,14 @@ namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern descriptor registered under the repr name "emb_elt_layernorm".
// It declares two pattern nodes: the op node and its output variable node.
// The pattern itself is built by operator()(), which is only declared here.
struct EmbEltwiseLayernorm : public PatternBase {
// `pattern` is the PDPattern the nodes are added to; `name_scope` is forwarded
// to PatternBase unchanged.
EmbEltwiseLayernorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "emb_elt_layernorm") {}
// Creates the pattern nodes and links (definition lives outside this struct).
void operator()();
// Node handle for the operator.
PATTERN_DECL_NODE(emb_elt_layernorm_op);
// Node handle for the operator's output variable.
PATTERN_DECL_NODE(emb_elt_layernorm_out);
};
struct SkipLayernorm : public PatternBase {
SkipLayernorm(PDPattern *pattern, const std::string &name_scope)
......
......@@ -21,129 +21,134 @@
namespace paddle {
namespace framework {
namespace ir {
// Constructor: registers the op-compat description for elementwise_add.
// The description lists inputs "X" and "Y" and output "Out", each marked
// IsTensor(), plus an "axis" attribute with no further constraint chained
// before End().
SetTransformerInputConvertPass::SetTransformerInputConvertPass() {
AddOpCompat(OpCompat("elementwise_add"))
// Input "X" is marked as a tensor.
.AddInput("X")
.IsTensor()
.End()
// Input "Y" is marked as a tensor.
.AddInput("Y")
.IsTensor()
.End()
// Output "Out" is marked as a tensor.
.AddOutput("Out")
.IsTensor()
.End()
// "axis" is declared; no value restriction is added here.
.AddAttr("axis")
.End();
}
namespace patterns {
void SetTransformerInputConvert::operator()() {
void SetTransformerInputConvert::operator()(const std::string &pos_id) {
std::unordered_set<std::string> lookup_table_ops{"lookup_table",
"lookup_table_v2"};
// Create nodes for lookup_table1 op.
auto *lookup_table1_x = pattern->NewNode(lookup_table1_x_repr())
->assert_is_ops_input(lookup_table_ops, "Ids");
auto *lookup_table1_w = pattern->NewNode(lookup_table1_w_repr())
->assert_is_ops_input(lookup_table_ops, "W");
auto *lookup_table1_op =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(lookup_table_ops);
auto *lookup_table1_out = pattern->NewNode(lookup_table1_out_repr())
->assert_is_ops_output(lookup_table_ops)
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
// Create nodes for lookup_table2 op.
auto *lookup_table2_x = pattern->NewNode(lookup_table2_x_repr())
->assert_is_ops_input(lookup_table_ops, "Ids");
auto *lookup_table2_w = pattern->NewNode(lookup_table2_w_repr())
->assert_is_ops_input(lookup_table_ops, "W");
auto *lookup_table2_op =
pattern->NewNode(lookup_table2_repr())->assert_is_ops(lookup_table_ops);
auto *lookup_table2_out = pattern->NewNode(lookup_table2_out_repr())
->assert_is_ops_output(lookup_table_ops)
->AsIntermediate()
->assert_is_op_input("elementwise_add", "Y");
// Create nodes for elementwise_add op.
auto *elementwise_op =
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
auto *elementwise_out = pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_only_output_of_op("elementwise_add");
// Create nodes for lookup_table.
auto *lookup_table_id =
pattern->NewNode(lookup_table_id_repr())
->assert_is_ops_input(lookup_table_ops, "Ids")
->assert_more([&](Node *node) { return node->Name() == pos_id; });
auto *lookup_table_op =
pattern->NewNode(lookup_table_repr())->assert_is_ops(lookup_table_ops);
// links nodes.
lookup_table1_op->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
lookup_table2_op->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
elementwise_op->LinksFrom({lookup_table1_out, lookup_table2_out})
.LinksTo({elementwise_out});
lookup_table_op->LinksFrom({lookup_table_id});
}
// Builds the pattern for a multihead_matmul op and the variable bound to its
// "Out" output.
void MultiheadMatmulOP::operator()() {
  // Operator node: must be a multihead_matmul op.
  auto* mha_op = pattern->NewNode(multihead_matmul_repr())
                     ->assert_is_op("multihead_matmul");

  // Variable node: must be the "Out" output of a multihead_matmul op.
  auto* mha_out = pattern->NewNode(multihead_matmul_out_repr())
                      ->assert_is_op_output("multihead_matmul", "Out");

  // Link the output variable back to the op that produces it.
  mha_out->LinksFrom({mha_op});
}
} // namespace patterns
void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const {
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
if (!(graph->Has(framework::ir::kMultiheadMatmulPass) && with_dynamic_shape &&
(pos_id != ""))) {
VLOG(3) << "Transformer model need MultiheadMatmul, and "
"with_dynamic_shape. Stop this pass, "
"please reconfig.";
return;
}
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init(name_scope_, graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
Node *transformer_input_convert_out0_node;
Node *transformer_input_convert_out1_node;
GraphPatternDetector gpd0;
patterns::SetTransformerInputConvert fused_pattern(
gpd.mutable_pattern(), "transformer_input_convert_pass");
fused_pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
gpd0.mutable_pattern(), "transformer_input_convert_pass");
fused_pattern(pos_id);
auto handler0 = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "transformer_input_convert_pass in op compat failed.";
return;
}
VLOG(3) << "transformer_input_convert_pass for pos_id, max_seqlen";
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, fused_pattern);
VLOG(3)
<< "transformer_input_convert_pass for pos_id, max_seqlen, mask_tensor";
GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table_id, lookup_table_id, fused_pattern);
// create op, var in graph
OpDesc new_desc;
OpDesc new_desc(lookup_table->Op()->Block());
new_desc.SetType("transformer_input_convert");
// inputs
new_desc.SetInput("X", {lookup_table2_x->Name()});
new_desc.SetInput("Input", {lookup_table_id->Name()});
// outputs
std::vector<std::string> output_0 = {"pos_id_tensor"};
std::vector<std::string> output_1 = {"max_seqlen_tensor"};
new_desc.SetOutput("PosId", output_0);
new_desc.SetOutput("MaxSeqlen", output_1);
std::string transformer_input_convert_out0_name = "pos_id_tensor";
std::string transformer_input_convert_out1_name = "max_seqlen_tensor";
VarDesc transformer_input_convert_out0(transformer_input_convert_out0_name);
VarDesc transformer_input_convert_out1(transformer_input_convert_out1_name);
transformer_input_convert_out0.SetDataType(proto::VarType::INT32);
transformer_input_convert_out1.SetDataType(proto::VarType::INT32);
transformer_input_convert_out0.SetShape({-1});
transformer_input_convert_out1.SetShape({-1});
transformer_input_convert_out0.SetPersistable(false);
transformer_input_convert_out1.SetPersistable(false);
std::string transformer_input_convert_out2_name = "mask_tensor";
std::vector<std::string> output_0 = {transformer_input_convert_out0_name};
std::vector<std::string> output_1 = {transformer_input_convert_out1_name};
std::vector<std::string> output_2 = {transformer_input_convert_out2_name};
new_desc.SetOutput("PosId", output_0);
new_desc.SetOutput("MaxSeqlen", output_1);
new_desc.SetOutput("MaskTensor", output_2);
auto *transformer_input_convert_out0 =
lookup_table->Op()->Block()->Var(transformer_input_convert_out0_name);
auto *transformer_input_convert_out1 =
lookup_table->Op()->Block()->Var(transformer_input_convert_out1_name);
auto *transformer_input_convert_out2 =
lookup_table->Op()->Block()->Var(transformer_input_convert_out2_name);
transformer_input_convert_out0->SetDataType(proto::VarType::INT32);
transformer_input_convert_out1->SetDataType(proto::VarType::INT32);
transformer_input_convert_out2->SetDataType(proto::VarType::INT32);
transformer_input_convert_out0->SetShape({-1});
transformer_input_convert_out1->SetShape({-1});
transformer_input_convert_out2->SetShape({-1});
transformer_input_convert_out0->SetPersistable(false);
transformer_input_convert_out1->SetPersistable(false);
transformer_input_convert_out2->SetPersistable(false);
auto new_op_node = graph->CreateOpNode(&new_desc);
auto transformer_input_convert_out0_node =
graph->CreateVarNode(&transformer_input_convert_out0);
graph->CreateVarNode(transformer_input_convert_out0);
auto transformer_input_convert_out1_node =
graph->CreateVarNode(&transformer_input_convert_out1);
graph->CreateVarNode(transformer_input_convert_out1);
auto transformer_input_convert_out2_node =
graph->CreateVarNode(transformer_input_convert_out2);
// needn't create variable in scope
IR_NODE_LINK_TO(lookup_table2_x, new_op_node);
IR_NODE_LINK_TO(lookup_table_id, new_op_node);
IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out0_node);
IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out1_node);
IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out2_node);
};
gpd0(graph, handler0);
found_subgraph_count++;
GraphPatternDetector gpd1;
patterns::MultiheadMatmulOP multihead_matmul_pattern(
gpd1.mutable_pattern(), "transformer_input_convert_pass");
multihead_matmul_pattern();
auto handler1 = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
VLOG(3) << "link pos_id, max_seqlen to multihead_matmul.";
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul,
multihead_matmul_pattern);
IR_NODE_LINK_TO(transformer_input_convert_out0_node, multihead_matmul);
IR_NODE_LINK_TO(transformer_input_convert_out1_node, multihead_matmul);
};
gpd1(graph, handler1);
gpd(graph, handler);
found_subgraph_count++;
AddStatis(found_subgraph_count);
}
......@@ -153,9 +158,3 @@ void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const {
// Register the pass under the name set_transformer_input_convert_pass.
REGISTER_PASS(set_transformer_input_convert_pass,
              paddle::framework::ir::SetTransformerInputConvertPass);
// Declare the op versions this pass is compatible with. The third entry used
// to be spelled "elementweise_add", so the version bound was attached to a
// name that matches no op used by this pass; it now names elementwise_add.
REGISTER_PASS_CAPABILITY(set_transformer_input_convert_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("lookup_table", 1)
            .LE("lookup_table_v2", 1)
            .LE("elementwise_add", 1));
......@@ -33,41 +33,36 @@ namespace framework {
namespace ir {
namespace patterns {
// in_var emb in_var emb
// | | | |
// lookup_table lookup_table
// in_var emb
// | |
// lkt_var lkt_var
// \ /
// elementwise_add
// lookup_table
// |
// elt_out_var
// lkt_var
//
struct SetTransformerInputConvert : public PatternBase {
SetTransformerInputConvert(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "transformer_input_convert") {}
: PatternBase(pattern, name_scope, "transformer_input_convert_pass") {}
void operator()(const std::string &pos_id);
// declare operator node's name
PATTERN_DECL_NODE(lookup_table);
// declare variable node's name
PATTERN_DECL_NODE(lookup_table_id);
};
struct MultiheadMatmulOP : public PatternBase {
MultiheadMatmulOP(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "transformer_input_convert_pass") {}
void operator()();
// declare operator node's name
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table2);
PATTERN_DECL_NODE(elementwise);
// declare variable node's name
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table2_w);
PATTERN_DECL_NODE(lookup_table2_out);
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(multihead_matmul);
PATTERN_DECL_NODE(multihead_matmul_out);
};
} // namespace patterns
class SetTransformerInputConvertPass : public FusePassBase {
public:
SetTransformerInputConvertPass();
SetTransformerInputConvertPass() {}
virtual ~SetTransformerInputConvertPass() {}
protected:
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Creates a pattern node named `name` that must be the `arg` input of a
// lookup_table or lookup_table_v2 op. When `is_persist` is true the node is
// additionally passed through assert_is_persistable_var().
static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
                               const std::string& arg,
                               bool is_persist = false) {
  std::unordered_set<std::string> lookup_ops{"lookup_table",
                                             "lookup_table_v2"};
  PDNode* emb_input =
      pattern->NewNode(name)->assert_is_ops_input(lookup_ops, arg);
  if (!is_persist) {
    return emb_input;
  }
  return emb_input->assert_is_persistable_var();
}
// Creates an intermediate pattern node named `name` that must be the only
// output of a lookup_table / lookup_table_v2 op and also the `arg` input of an
// elementwise_add op.
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
                                   const std::string& arg) {
  std::unordered_set<std::string> lookup_ops{"lookup_table",
                                             "lookup_table_v2"};
  PDNode* emb_output = pattern->NewNode(name)
                           ->assert_is_only_output_of_ops(lookup_ops)
                           ->assert_is_op_input("elementwise_add", arg)
                           ->AsIntermediate();
  return emb_output;
}
// Builds the pattern for two fed embedding lookups whose outputs are summed:
//   feed -> Ids --\
//                  lookup_table{,_v2} -> (X | Y) -> elementwise_add -> out
//            W  --/
// Node creation and link order are kept as-is.
void TrtEmbedding2Eltwise1Pattern::operator()() {
  // "Ids" inputs of both lookups, then their "W" inputs (asserted persistable).
  auto* ids_a = create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
  auto* ids_b = create_emb_vars(pattern, lookup_table2_x_repr(), "Ids");
  auto* weight_a = create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
  auto* weight_b = create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);

  std::unordered_set<std::string> lookup_ops{"lookup_table",
                                             "lookup_table_v2"};

  // feed ops that produce the two Ids variables.
  auto* feed_a = pattern->NewNode(feed1_repr())->assert_is_op("feed");
  auto* feed_b = pattern->NewNode(feed2_repr())->assert_is_op("feed");

  // The two lookup ops.
  auto* lookup_a =
      pattern->NewNode(lookup_table1_repr())->assert_is_ops(lookup_ops);
  auto* lookup_b =
      pattern->NewNode(lookup_table2_repr())->assert_is_ops(lookup_ops);

  // Lookup outputs, consumed as X and Y of elementwise_add.
  auto* lookup_a_out =
      create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
  auto* lookup_b_out =
      create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y");

  // The add op and its output variable.
  auto* add_op =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto* add_out = pattern->NewNode(eltwise_add_out_repr())
                      ->assert_is_op_output("elementwise_add");

  // Wire the nodes together.
  feed_a->LinksTo({ids_a});
  lookup_a->LinksFrom({ids_a, weight_a}).LinksTo({lookup_a_out});
  feed_b->LinksTo({ids_b});
  lookup_b->LinksFrom({ids_b, weight_b}).LinksTo({lookup_b_out});
  add_op->LinksFrom({lookup_a_out, lookup_b_out}).LinksTo({add_out});
}
// Builds the pattern for one fed embedding lookup added onto the output of a
// previous elementwise_add:
//   prev elementwise_add out (X) --\
//                                   elementwise_add -> out
//   feed -> Ids -> lookup (Y)    --/
// Node creation and link order are kept as-is.
void TrtEmbedding1Eltwise1Pattern::operator()() {
  // "Ids" input and "W" input (asserted persistable) of the lookup.
  auto* ids = create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
  auto* weight = create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);

  std::unordered_set<std::string> lookup_ops{"lookup_table",
                                             "lookup_table_v2"};

  // feed op producing Ids, and the lookup op itself.
  auto* feed_op = pattern->NewNode(feed1_repr())->assert_is_op("feed");
  auto* lookup_op =
      pattern->NewNode(lookup_table1_repr())->assert_is_ops(lookup_ops);

  // Lookup output, consumed as Y of elementwise_add.
  auto* lookup_out =
      create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");

  // The add op, its X input (itself an elementwise_add output) and its output.
  auto* add_op =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto* add_in = pattern->NewNode(eltwise_add_in_repr())
                     ->assert_is_op_input("elementwise_add", "X")
                     ->assert_is_op_output("elementwise_add");
  auto* add_out = pattern->NewNode(eltwise_add_out_repr())
                      ->assert_is_op_output("elementwise_add");

  // Wire the nodes together.
  lookup_op->LinksFrom({ids, weight}).LinksTo({lookup_out});
  feed_op->LinksTo({ids});
  add_op->LinksFrom({lookup_out, add_in}).LinksTo({add_out});
}
// Builds the pattern elementwise_add -> layer_norm, including layer_norm's
// Bias/Scale inputs and its Y/Mean/Variance outputs.
// Node creation and link order are kept as-is.
void TrtSkipLayerNorm::operator()() {
  // The add op and its output, which feeds layer_norm's X (intermediate node).
  auto* add_op =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto* add_out = pattern->NewNode(eltwise_add_out_repr())
                      ->assert_is_op_output("elementwise_add")
                      ->assert_is_op_input("layer_norm", "X")
                      ->AsIntermediate();

  // layer_norm op and its main output Y.
  auto* ln_op = pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
  auto* ln_out = pattern->NewNode(layer_norm_out_repr())
                     ->assert_is_op_output("layer_norm", "Y")
                     ->AsOutput();

  // Persistable Bias and Scale inputs of layer_norm.
  auto* ln_bias = pattern->NewNode(layer_norm_bias_repr())
                      ->AsInput()
                      ->assert_is_persistable_var()
                      ->assert_is_op_input("layer_norm", "Bias");
  auto* ln_scale = pattern->NewNode(layer_norm_scale_repr())
                       ->AsInput()
                       ->assert_is_persistable_var()
                       ->assert_is_op_input("layer_norm", "Scale");

  // Mean and Variance outputs of layer_norm.
  auto* ln_mean = pattern->NewNode(layer_norm_mean_repr())
                      ->AsOutput()
                      ->assert_is_op_output("layer_norm", "Mean");
  auto* ln_variance = pattern->NewNode(layer_norm_variance_repr())
                          ->AsOutput()
                          ->assert_is_op_output("layer_norm", "Variance");

  // Wire the nodes together.
  add_op->LinksTo({add_out});
  ln_op->LinksFrom({add_out, ln_bias, ln_scale})
      .LinksTo({ln_out, ln_mean, ln_variance});
}
} // namespace patterns
// Fuses a chain of embedding lookups, elementwise additions and a trailing
// layer_norm into one "fused_embedding_eltwise_layernorm" op.
//
// Three patterns are matched independently over the whole graph:
//   * start : two lookup_table ops whose outputs are summed by one add,
//   * inner : one more lookup_table summed onto a previous add result,
//   * end   : an add result that feeds a layer_norm.
// The matches are then stitched together by following the add outputs, and one
// fused op is created for every start match whose chain ends in an end match.
//
// Reads the pass attributes "use_varseqlen", "tensorrt_transformer_posid" and
// "tensorrt_transformer_maskid".
//
// Returns the number of fused ops that were created (0 if the start or the end
// pattern did not match at all).
int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(
Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
bool use_varseqlen = Get<bool>("use_varseqlen");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
// Per start match (all three vectors share the same index):
//   - the (ids, embedding table) node pair of each of the two lookups,
//   - the output var of the add,
//   - the nodes to delete if this match ends up being fused.
std::vector<std::vector<std::pair<Node*, Node*>>> start_pattern_in_nodes;
std::vector<Node*> start_pattern_out_node;
std::vector<std::unordered_set<Node*>> start_pattern_remove_nodes;
// Create the start pattern.
patterns::TrtEmbedding2Eltwise1Pattern start_pattern(pattern,
name_scope + "/start");
start_pattern();
// Only records the match; the graph is not modified inside the handlers.
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
// Matches that fail the op compatibility rules are dropped.
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(TrtEmbedding2Eltwise1Pattern) in op compat failed.";
return;
}
std::vector<std::pair<Node*, Node*>> ins;
ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w));
ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w));
start_pattern_in_nodes.push_back(ins);
start_pattern_out_node.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
lookup_table2_out, eltwise_add, eltwise_add_out});
start_pattern_remove_nodes.push_back(rm_nodes);
};
gpd(graph, handler);
// Per inner match (shared index): the (ids, embedding table) pair of the
// lookup, the add input that comes from the previous add, the add output,
// and the nodes to delete.
std::vector<std::pair<Node*, Node*>> inner_pattern_ins;
std::vector<Node*> inner_pattern_tmp_in;
std::vector<Node*> inner_pattern_out;
std::vector<std::unordered_set<Node*>> inner_pattern_remove_nodes;
GraphPatternDetector gpd2;
auto* pattern2 = gpd2.mutable_pattern();
patterns::TrtEmbedding1Eltwise1Pattern second_pattern(pattern2,
name_scope + "/second");
second_pattern();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(TrtEmbedding1Eltwise1Pattern) in op compat failed.";
return;
}
auto in = std::make_pair(lookup_table1_x, lookup_table1_w);
inner_pattern_ins.push_back(in);
inner_pattern_tmp_in.push_back(eltwise_add_in);
inner_pattern_out.push_back(eltwise_add_out);
// eltwise_add_in is not added here: it is the output of the previous add
// and is already part of that match's removal set.
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert(
{lookup_table1, lookup_table1_out, eltwise_add, eltwise_add_out});
inner_pattern_remove_nodes.push_back(rm_nodes);
};
gpd2(graph, handler2);
// Per end match (shared index): the add output feeding layer_norm, the
// Scale/Bias vars, the layer_norm output, the layer_norm op itself and the
// nodes to delete.
std::vector<Node*> end_pattern_elt_out;
std::vector<Node*> end_pattern_scales;
std::vector<Node*> end_pattern_biases;
std::vector<Node*> end_pattern_out;
std::vector<Node*> end_patter_layernorms;
std::vector<std::unordered_set<Node*>> end_pattern_remove_nodes;
GraphPatternDetector gpd3;
auto* pattern3 = gpd3.mutable_pattern();
patterns::TrtSkipLayerNorm skip_layernorm_pattern(pattern3,
name_scope + "/third");
skip_layernorm_pattern();
auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
skip_layernorm_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(TrtSkipLayerNorm) in op compat failed.";
return;
}
end_pattern_elt_out.push_back(eltwise_add_out);
// The add and its output are removed through the start/inner match that
// produced them; only layer_norm and its statistics outputs are listed here.
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance});
end_pattern_remove_nodes.push_back(rm_nodes);
end_pattern_biases.push_back(layer_norm_bias);
end_pattern_scales.push_back(layer_norm_scale);
end_pattern_out.push_back(layer_norm_out);
end_patter_layernorms.push_back(layer_norm);
};
gpd3(graph, handler3);
// Without at least one start match and one end match nothing can be fused.
if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) {
return 0;
}
// Only keep the subgraphs that form a connected start -> inner* -> end chain.
int fusion_count = 0;
// fusion_id for (i, k, js): start match index, end match index and the
// ordered list of inner match indices in between.
std::vector<std::pair<size_t, std::pair<size_t, std::vector<size_t>>>>
fusion_ids;
for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) {
Node* tmp = start_pattern_out_node[i];
Node* old_tmp = nullptr;
// Walk forward through the inner matches: as long as some inner match takes
// the current add output as its input, advance to that match's output. The
// loop stops once a full scan leaves `tmp` unchanged.
std::vector<size_t> js;
while (tmp != old_tmp) {
old_tmp = tmp;
for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) {
if (inner_pattern_tmp_in[j] == tmp) {
tmp = inner_pattern_out[j];
js.push_back(j);
break;
}
}
}
// The chain is fusable only if its last add output feeds a matched
// layer_norm.
for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) {
if (tmp == end_pattern_elt_out[k]) {
fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js)));
break;
}
}
}
// Create one fused op per recorded chain and delete the nodes it replaces.
for (size_t num = 0; num < fusion_ids.size(); ++num) {
int i = fusion_ids[num].first;
int k = fusion_ids[num].second.first;
std::vector<size_t> js = fusion_ids[num].second.second;
// Collect id / embedding-table variable names in chain order: the two
// lookups of the start match first, then one per inner match.
std::vector<std::string> ids;
std::vector<std::string> embs;
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
}
for (size_t iter = 0; iter < js.size(); ++iter) {
ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
// NOTE(review): the block is taken from layer_norm match 0 rather than match
// k; presumably all matches live in the same block - confirm.
OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("WordId", {ids[0]});
// With varseqlen and both names configured, the user-supplied variable names
// are used for PosId/MaskId; otherwise PosId is the second lookup's ids.
// NOTE(review): no IR links are created below for the pos_id/mask_id
// variables - confirm they are wired elsewhere.
if (use_varseqlen && pos_id != "" && mask_id != "") {
new_op_desc.SetInput("PosId", {pos_id});
new_op_desc.SetInput("MaskId", {mask_id});
} else {
new_op_desc.SetInput("PosId", {ids[1]});
}
// A third lookup, if present, is exposed as the sentence input.
if (ids.size() > 2) {
new_op_desc.SetInput("SentId", {ids[2]});
}
new_op_desc.SetInput("WordEmbedding", {embs[0]});
new_op_desc.SetInput("PosEmbedding", {embs[1]});
if (embs.size() > 2) {
new_op_desc.SetInput("SentEmbedding", {embs[2]});
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));
// An "out_threshold" attribute on layer_norm switches the fused op to int8
// and is forwarded unchanged.
if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}
auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
// Connect every ids / embedding-table var, Bias and Scale to the new op, and
// the new op to the original layer_norm output.
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);
// Remove the nodes replaced by the fused op (start, end and inner matches).
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}
return fusion_count;
}
// Registers the op compatibility rules that IsCompat() checks against matched
// subgraphs: one rule set for elementwise_add and one for layer_norm.
TrtEmbeddingEltwiseLayerNormFusePass::TrtEmbeddingEltwiseLayerNormFusePass() {
  // elementwise_add: tensor inputs X and Y, tensor output Out, any "axis".
  auto&& add_rule = AddOpCompat(OpCompat("elementwise_add"));
  add_rule.AddInput("X").IsTensor().End();
  add_rule.AddInput("Y").IsTensor().End();
  add_rule.AddOutput("Out").IsTensor().End();
  add_rule.AddAttr("axis").End();

  // layer_norm: tensor inputs X/Scale/Bias and tensor outputs Y/Mean/Variance.
  auto&& norm_rule = AddOpCompat(OpCompat("layer_norm"));
  norm_rule.AddInput("X").IsTensor().End();
  norm_rule.AddInput("Scale").IsTensor().End();
  norm_rule.AddInput("Bias").IsTensor().End();
  norm_rule.AddOutput("Y").IsTensor().End();
  norm_rule.AddOutput("Mean").IsTensor().End();
  norm_rule.AddOutput("Variance").IsTensor().End();
  // epsilon is restricted to [0, 0.001]; begin_norm_axis must be > 0.
  norm_rule.AddAttr("epsilon").IsNumGE(0.0f).IsNumLE(0.001f).End();
  norm_rule.AddAttr("begin_norm_axis").IsNumGT(0).End();
}
// Entry point of the pass.
//
// Skips the pass entirely unless "with_dynamic_shape" is set. Otherwise runs
// BuildFusion() and, if at least one subgraph was fused, validates the
// varseqlen configuration and marks the graph with kEmbEltwiseLayernormPass.
//
// Throws (PADDLE_THROW, Fatal) when a fusion happened but the configuration is
// inconsistent: either use_varseqlen with both pos_id and mask_id set, or no
// use_varseqlen with both left empty, is accepted.
void TrtEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
  // NOTE: the log text also names use_varseqlen, but only with_dynamic_shape
  // is actually checked here.
  bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
  if (!with_dynamic_shape) {
    VLOG(3) << "trt_embedding_eltwise_layernorm_fuse_pass need: use_varseqlen, "
               "with_dynamic_shape. Stop this pass, "
               "please reconfig.";
    return;
  }
  FusePassBase::Init(name_scope_, graph);
  int fusion_count =
      TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
  if (fusion_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
    // Valid: varseqlen fully configured, or varseqlen fully unconfigured.
    const bool varseqlen_configured =
        use_varseqlen && !pos_id.empty() && !mask_id.empty();
    const bool varseqlen_unused =
        !use_varseqlen && pos_id.empty() && mask_id.empty();
    if (varseqlen_configured || varseqlen_unused) {
      VLOG(3) << "start trt_embedding_eltwise_layernorm_fuse_pass";
    } else {
      // The message now matches the condition above: without varseqlen,
      // neither pos_id nor mask_id may be set.
      PADDLE_THROW(
          platform::errors::Fatal("Use transformer's varseqlen need config: "
                                  "use_varseqlen, set pos_id, set "
                                  "mask_id. Or not use varseqlen, do not set "
                                  "pos_id, do not set mask_id. Please "
                                  "reconfig"));
    }
    // Record on the graph that this fusion has been applied.
    graph->Set(kEmbEltwiseLayernormPass, new bool(true));
  }
  AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers the pass under the name used by the pass builder.
REGISTER_PASS(trt_embedding_eltwise_layernorm_fuse_pass,
              paddle::framework::ir::TrtEmbeddingEltwiseLayerNormFusePass);
// Declares the op versions this pass is compatible with. The op name must be
// spelled "elementwise_add" so that the constraint applies to the op the pass
// actually matches.
REGISTER_PASS_CAPABILITY(trt_embedding_eltwise_layernorm_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("lookup_table", 1)
            .LE("lookup_table_v2", 1)
            .LE("elementwise_add", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Detects the start pattern: two embedding lookups whose results are summed.
//
//     in_var  emb       in_var   emb
//       |      |          |       |
//     lookup_table      lookup_table
//           |                 |
//        lkt_var           lkt_var
//            \                /
//             elementwise_add
//                    |
//               elt_out_var
//
struct TrtEmbedding2Eltwise1Pattern : public PatternBase {
TrtEmbedding2Eltwise1Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "embedding2_eltwise1") {}
// Builds the pattern nodes and edges (defined in the .cc file).
void operator()();
// Named pattern nodes. Each declaration provides the `<name>_repr()` accessor
// used when building the pattern and the name used to fetch the matched node.
PATTERN_DECL_NODE(feed1);
PATTERN_DECL_NODE(feed2);
// Ids inputs of the two lookups.
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table2_x);
// Embedding tables of the two lookups.
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table2_w);
// The lookup ops and their outputs.
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table2);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(lookup_table2_out);
// The add that sums both lookup results, and its output.
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
};
// Detects the repeated inner pattern: one more embedding lookup added onto the
// result of a previous elementwise_add.
//
//    elt_out_var            in_var   emb
//         \                   |       |
//          \                 lookup_table
//           \                     |
//            \                 lkt_var
//             \                   /
//                elementwise_add
//                       |
//                  elt_out_var
//
struct TrtEmbedding1Eltwise1Pattern : public PatternBase {
TrtEmbedding1Eltwise1Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "embedding1_eltwise1") {}
// Builds the pattern nodes and edges (defined in the .cc file).
void operator()();
PATTERN_DECL_NODE(feed1);
// Ids input, embedding table, lookup op and lookup output.
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table1_out);
// Add input that comes from the preceding add (elt_out_var in the diagram).
PATTERN_DECL_NODE(eltwise_add_in);
// The add op and its output.
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
};
// Detects the end pattern: the result of an elementwise_add normalized by
// layer_norm.
//
//     elementwise_add
//            |
//       elt_out_var
//    scale   |   bias
//       \    |    /
//       layer_norm
//
struct TrtSkipLayerNorm : public PatternBase {
TrtSkipLayerNorm(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "skip_layernorm") {}
// Builds the pattern nodes and edges (defined in the .cc file).
void operator()();
// The add op and its output, which is layer_norm's "X" input.
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
// The layer_norm op, its Bias/Scale parameters and its "Y" output.
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
// Mean and Variance outputs; the fuse pass deletes these nodes from the graph.
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
};
} // namespace patterns
// The TrtEmbeddingEltwiseLayerNormFusePass detects the following pattern:
//
// inputs                           operator           output
// --------------------------------------------------------------------
// (word, weights_0)                lookup_table    -> word_emb
// (pos, weights_1)                 lookup_table    -> pos_emb
// (sent, weights_2)                lookup_table    -> sent_emb
// (word_emb, pos_emb)              elementwise_add -> elementwise_out_0
// (elementwise_out_0, sent_emb)    elementwise_add -> elementwise_out_1
// (elementwise_out_1, scale, bias) layer_norm      -> layer_norm_out
//
// and then converts the corresponding subgraph to:
//
// (word, pos, sent, weights_0, weights_1, weights_2,
//  scale, bias)      fused_embedding_eltwise_layernorm -> layer_norm_out
//
//
//  in_var  emb_var   in_var   emb_var   in_var   emb_var      in_var   emb_var
//    |        |        |         |        |         |           |         |
//   lookup_table      lookup_table       lookup_table   ...    lookup_table
//        |                 |                  |                     |
//     lkt_var           lkt_var            lkt_var               lkt_var
//        \                 /                  |         ...         |
//          elementwise_add                    |                     |
//                 \                          /                      |
//                       elementwise_add                             |
//                               |                                   |
//                            elt_var                               /
//                               \                                 /
//                                         elementwise_add
//                                                 |
//                                            layer_norm
class TrtEmbeddingEltwiseLayerNormFusePass : public FusePassBase {
public:
// Registers the op compatibility rules for elementwise_add and layer_norm.
TrtEmbeddingEltwiseLayerNormFusePass();
virtual ~TrtEmbeddingEltwiseLayerNormFusePass() {}
protected:
// Runs the pass on `graph`; does nothing unless "with_dynamic_shape" is set.
void ApplyImpl(Graph* graph) const;
// Matches and fuses the subgraphs; returns the number of fused ops created.
int BuildFusion(Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const;
// Name scope passed to FusePassBase::Init and used to prefix pattern names.
const std::string name_scope_{"trt_embedding_eltwise_layernorm_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Re-targets `op` so that it writes `new_var` wherever it used to write
// `old_var`: the op is registered as a producer of `new_var`, every matching
// entry in the op's output list is swapped, and the output is renamed in the
// op description. Does nothing if `op` is not an op node carrying an OpDesc.
static void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) {
  if (!op->IsOp() || !op->Op()) {
    return;
  }
  new_var->inputs.push_back(op);
  for (auto& produced : op->outputs) {
    if (produced != old_var) {
      continue;
    }
    produced = new_var;
    op->Op()->RenameOutput(old_var->Name(), new_var->Name());
  }
}
// Matches TrtMultiHeadMatmulPattern in `graph` and replaces every match with a
// single "multihead_matmul" op.
//
// The three projection ops (mul0/mul1/mul2) are kept; their outputs are
// redirected to freshly created Q/K/V variables that feed the fused op. All
// other ops and intermediate variables of the match are removed.
//
// Returns the number of matches that were fused.
static int BuildFusion(Graph* graph, const std::string& name_scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
TrtMultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
multihead_pattern();
// Creates the fused op for one match and wires it into the graph.
// `input0` and `scale_out` are accepted but not used in the body.
auto fuse_creater = [&](
Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b,
Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2,
Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
// The scale factor of the matched scale op becomes "alpha" of the fused op.
auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
// auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias"));
// bool after_scale =
// BOOST_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale"));
// create multihead
OpDesc multihead_op_desc(mul0->Op()->Block());
// Create the Q/K/V variables: copies of the mul outputs' descriptions under a
// new name ("Q"/"K"/"V" prefix + original name).
VarDesc k_var_desc(*mul1_out->Var());
k_var_desc.SetName("K" + mul1_out->Name());
auto* k_var_node = graph->CreateVarNode(&k_var_desc);
VarDesc q_var_desc(*mul0_out->Var());
q_var_desc.SetName("Q" + mul0_out->Name());
auto* q_var_node = graph->CreateVarNode(&q_var_desc);
VarDesc v_var_desc(*mul2_out->Var());
v_var_desc.SetName("V" + mul2_out->Name());
auto* v_var_node = graph->CreateVarNode(&v_var_desc);
// The head count is read from index 2 of the reshape2 "shape" attribute.
auto reshape_desc = reshape2->Op();
int head_number =
BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")).at(2);
// Make the projection ops write to the new Q/K/V variables.
ReplaceOutputVar(mul0, mul0_out, q_var_node);
ReplaceOutputVar(mul1, mul1_out, k_var_node);
ReplaceOutputVar(mul2, mul2_out, v_var_node);
multihead_op_desc.SetType("multihead_matmul");
multihead_op_desc.SetInput("Q", {q_var_node->Name()});
multihead_op_desc.SetInput("K", {k_var_node->Name()});
multihead_op_desc.SetInput("V", {v_var_node->Name()});
// The "Y" operands of the matched elementwise_add ops become the biases.
multihead_op_desc.SetInput("BiasQ", {eltadd0_b->Name()});
multihead_op_desc.SetInput("BiasK", {eltadd1_b->Name()});
multihead_op_desc.SetInput("BiasV", {eltadd2_b->Name()});
multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
// The fused op writes directly to the output of the final reshape2.
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
multihead_op_desc.SetAttr("alpha", scale_attr);
multihead_op_desc.SetAttr("head_number", head_number);
auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(q_var_node, multihead);
IR_NODE_LINK_TO(k_var_node, multihead);
IR_NODE_LINK_TO(v_var_node, multihead);
IR_NODE_LINK_TO(eltadd0_b, multihead);
IR_NODE_LINK_TO(eltadd1_b, multihead);
IR_NODE_LINK_TO(eltadd2_b, multihead);
IR_NODE_LINK_TO(eltadd_qk_b, multihead);
IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
};
int fusion_count{0};
// Invoked once per match: fetch the matched nodes, build the fused op, then
// remove the nodes it replaces.
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
// reshape2_0 (Q path) is the reshape whose "shape" provides the head count.
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
reshape2_qkv_out, scale, scale_out);
// Everything between the projection ops and the final output is deleted. The
// mul ops, their weights, the bias vars and reshape2_qkv_out are kept.
// NOTE(review): scale_out is not in this set although its producer (scale)
// and consumer (matmul_qk) are - confirm this is intended.
std::unordered_set<const Node*> marked_nodes(
{eltadd0,
eltadd1,
eltadd2,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out, // dropout_qk, dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
mul0_out,
mul1_out,
mul2_out,
reshape2_qkv,
scale});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
// Builds the multi-head attention pattern based on "mul" projections.
//
// Three branches start from the same input (input0):
//   Q: mul -> elementwise_add -> reshape2 -> transpose2 -> scale
//   K: mul -> elementwise_add -> reshape2 -> transpose2
//   V: mul -> elementwise_add -> reshape2 -> transpose2
// They are combined as
//   matmul(Q, K) -> elementwise_add -> softmax -> matmul(., V)
//     -> transpose2 -> reshape2
// and the final reshape2 output must feed a "mul" op.
//
// Returns the pattern node of the V-branch transpose2 output.
PDNode* TrtMultiHeadMatmulPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("mul");
// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul");
decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
decltype(mul0) eltadd0_out_var;
mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add");
eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var =
pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out_var =
pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
scale_out_var->AsIntermediate()->assert_is_op_input("matmul");
// Q*K product, its additive term and the softmax.
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add");
eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul");
// (softmax * V) product followed by transpose2 and reshape2.
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
// The attention output is not marked intermediate; it must feed a "mul" op.
reshape2_qkv_out_var->assert_is_op_input("mul");
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul");
decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
decltype(mul1) eltadd1_out_var;
mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add");
eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var =
pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul");
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul");
decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
decltype(mul2) eltadd2_out_var;
mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add");
eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var =
pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qkv
// Edges. All three projections read the same input0.
// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
// K path
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// compute q*k
matmul_qk->LinksFrom({scale_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// compute q*k*v
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
return transpose2_2_out_var;
}
// Describes the V3 multi-head attention subgraph.
//
// One shared input feeds three matmul/matmul_v2 projections (Q, K, V). Each
// projection is followed by elementwise_add -> reshape2 -> transpose2. The Q
// and K branches meet in a matmul whose result goes through elementwise_add
// and softmax; that result is multiplied with the V branch and finally passed
// through transpose2 -> reshape2. This pattern contains no standalone scale
// op: the Q transpose output is linked straight into the Q*K matmul.
//
// Returns the pattern node for the output of the V-branch transpose2.
PDNode* TrtMultiHeadMatmulV3Pattern::operator()() {
  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};

  auto* input0 =
      pattern->NewNode(input0_repr())->assert_is_ops_input(matmul_ops);

  // ---- Q branch: matmul -> elementwise_add -> reshape2 -> transpose2 ----
  auto* q_fc = pattern->NewNode(mul0_repr())->assert_is_ops(matmul_ops);
  auto* q_fc_w = pattern->NewNode(mul0_w_repr())
                     ->AsInput()
                     ->assert_is_ops_input(matmul_ops, "Y");
  auto* q_fc_out = pattern->NewNode(mul0_out_repr())
                       ->assert_is_ops_output(matmul_ops)
                       ->AsIntermediate()
                       ->assert_is_op_input("elementwise_add");
  auto* q_add =
      pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
  auto* q_add_b = pattern->NewNode(eltadd0_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");
  auto* q_add_out = pattern->NewNode(eltadd0_out_repr())
                        ->assert_is_op_output("elementwise_add")
                        ->AsIntermediate()
                        ->assert_is_op_input("reshape2");
  auto* q_reshape =
      pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
  auto* q_reshape_out = pattern->NewNode(reshape2_0_out_repr())
                            ->assert_is_op_output("reshape2")
                            ->AsIntermediate()
                            ->assert_is_op_input("transpose2");
  auto* q_transpose =
      pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
  auto* q_transpose_out = pattern->NewNode(transpose2_0_out_repr())
                              ->assert_is_op_output("transpose2")
                              ->AsIntermediate()
                              ->assert_is_ops_input(matmul_ops, "X");

  // ---- Attention core: Q*K -> elementwise_add -> softmax -> (*V) ----
  auto* qk_matmul =
      pattern->NewNode(matmul_qk_repr())->assert_is_ops(matmul_ops);
  auto* qk_matmul_out = pattern->NewNode(matmul_qk_out_repr())
                            ->assert_is_ops_output(matmul_ops)
                            ->AsIntermediate()
                            ->assert_is_op_input("elementwise_add");
  auto* qk_add =
      pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
  auto* qk_add_b = pattern->NewNode(eltadd_qk_b_repr())
                       ->AsInput()
                       ->assert_is_op_input("elementwise_add", "Y");
  auto* qk_add_out = pattern->NewNode(eltadd_qk_out_repr())
                         ->assert_is_op_output("elementwise_add")
                         ->AsIntermediate()
                         ->assert_is_op_input("softmax");
  auto* qk_softmax =
      pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
  auto* qk_softmax_out = pattern->NewNode(softmax_qk_out_repr())
                             ->assert_is_op_output("softmax")
                             ->AsIntermediate()
                             ->assert_is_ops_input(matmul_ops);

  // ---- Output tail: (softmax*V) -> transpose2 -> reshape2 ----
  auto* qkv_matmul =
      pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
  auto* qkv_matmul_out = pattern->NewNode(matmul_qkv_out_repr())
                             ->assert_is_ops_output(matmul_ops)
                             ->AsIntermediate()
                             ->assert_is_op_input("transpose2");
  auto* qkv_transpose =
      pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
  auto* qkv_transpose_out = pattern->NewNode(transpose2_qkv_out_repr())
                                ->assert_is_op_output("transpose2")
                                ->AsIntermediate()
                                ->assert_is_op_input("reshape2");
  auto* qkv_reshape =
      pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
  // The final output is not marked intermediate; it must feed a following
  // matmul/matmul_v2.
  auto* qkv_reshape_out = pattern->NewNode(reshape2_qkv_out_repr())
                              ->assert_is_op_output("reshape2")
                              ->assert_is_ops_input(matmul_ops);

  // ---- K branch: matmul -> elementwise_add -> reshape2 -> transpose2 ----
  auto* k_fc = pattern->NewNode(mul1_repr())->assert_is_ops(matmul_ops);
  auto* k_fc_w = pattern->NewNode(mul1_w_repr())
                     ->AsInput()
                     ->assert_is_ops_input(matmul_ops, "Y");
  auto* k_fc_out = pattern->NewNode(mul1_out_repr())
                       ->assert_is_ops_output(matmul_ops)
                       ->AsIntermediate()
                       ->assert_is_op_input("elementwise_add");
  auto* k_add =
      pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
  auto* k_add_b = pattern->NewNode(eltadd1_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");
  auto* k_add_out = pattern->NewNode(eltadd1_out_repr())
                        ->assert_is_op_output("elementwise_add")
                        ->AsIntermediate()
                        ->assert_is_op_input("reshape2");
  auto* k_reshape =
      pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
  auto* k_reshape_out = pattern->NewNode(reshape2_1_out_repr())
                            ->assert_is_op_output("reshape2")
                            ->AsIntermediate()
                            ->assert_is_op_input("transpose2");
  auto* k_transpose =
      pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
  // Feeds the "Y" slot of the Q*K matmul.
  auto* k_transpose_out = pattern->NewNode(transpose2_1_out_repr())
                              ->assert_is_op_output("transpose2")
                              ->AsIntermediate()
                              ->assert_is_ops_input(matmul_ops, "Y");

  // ---- V branch: matmul -> elementwise_add -> reshape2 -> transpose2 ----
  auto* v_fc = pattern->NewNode(mul2_repr())->assert_is_ops(matmul_ops);
  auto* v_fc_w = pattern->NewNode(mul2_w_repr())
                     ->AsInput()
                     ->assert_is_ops_input(matmul_ops, "Y");
  auto* v_fc_out = pattern->NewNode(mul2_out_repr())
                       ->assert_is_ops_output(matmul_ops)
                       ->AsIntermediate()
                       ->assert_is_op_input("elementwise_add");
  auto* v_add =
      pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
  auto* v_add_b = pattern->NewNode(eltadd2_b_repr())
                      ->AsInput()
                      ->assert_is_op_input("elementwise_add", "Y");
  auto* v_add_out = pattern->NewNode(eltadd2_out_repr())
                        ->assert_is_op_output("elementwise_add")
                        ->AsIntermediate()
                        ->assert_is_op_input("reshape2");
  auto* v_reshape =
      pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
  auto* v_reshape_out = pattern->NewNode(reshape2_2_out_repr())
                            ->assert_is_op_output("reshape2")
                            ->AsIntermediate()
                            ->assert_is_op_input("transpose2");
  auto* v_transpose =
      pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
  // Feeds the (softmax*V) matmul.
  auto* v_transpose_out = pattern->NewNode(transpose2_2_out_repr())
                              ->assert_is_op_output("transpose2")
                              ->AsIntermediate()
                              ->assert_is_ops_input(matmul_ops);

  // ---- Wire the nodes together ----
  // Q branch
  q_fc->LinksFrom({input0, q_fc_w}).LinksTo({q_fc_out});
  q_add->LinksFrom({q_fc_out, q_add_b}).LinksTo({q_add_out});
  q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out});
  q_transpose->LinksFrom({q_reshape_out}).LinksTo({q_transpose_out});

  // K branch
  k_fc->LinksFrom({input0, k_fc_w}).LinksTo({k_fc_out});
  k_add->LinksFrom({k_fc_out, k_add_b}).LinksTo({k_add_out});
  k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out});
  k_transpose->LinksFrom({k_reshape_out}).LinksTo({k_transpose_out});

  // Q*K, bias add and softmax
  qk_matmul->LinksFrom({q_transpose_out, k_transpose_out})
      .LinksTo({qk_matmul_out});
  qk_add->LinksFrom({qk_matmul_out, qk_add_b}).LinksTo({qk_add_out});
  qk_softmax->LinksFrom({qk_add_out}).LinksTo({qk_softmax_out});

  // V branch
  v_fc->LinksFrom({input0, v_fc_w}).LinksTo({v_fc_out});
  v_add->LinksFrom({v_fc_out, v_add_b}).LinksTo({v_add_out});
  v_reshape->LinksFrom({v_add_out}).LinksTo({v_reshape_out});
  v_transpose->LinksFrom({v_reshape_out}).LinksTo({v_transpose_out});

  // softmax(Q*K) * V, then transpose2 and reshape2
  qkv_matmul->LinksFrom({qk_softmax_out, v_transpose_out})
      .LinksTo({qkv_matmul_out});
  qkv_transpose->LinksFrom({qkv_matmul_out}).LinksTo({qkv_transpose_out});
  qkv_reshape->LinksFrom({qkv_transpose_out}).LinksTo({qkv_reshape_out});

  return v_transpose_out;
}
} // namespace patterns
// Initializes the pass for `graph`, runs patterns::BuildFusion on it and
// records the number of fusions it reports via AddStatis.
void TrtMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  AddStatis(patterns::BuildFusion(graph, name_scope_));
}
// Registers the op-compatibility constraints for every op type this pass
// declares: mul, elementwise_add, reshape2, transpose2, scale, matmul and
// softmax.
TrtMultiHeadMatmulV2FusePass::TrtMultiHeadMatmulV2FusePass() {
  // mul: X should be (B, S, N*H), Y should be (N*H, N*H),
  // Out should be (B, S, N*H).
  auto& mul_compat = AddOpCompat(OpCompat("mul"));
  mul_compat.AddInput("X").IsTensor().End();
  mul_compat.AddInput("Y").IsTensor().End();
  mul_compat.AddOutput("Out").IsTensor().End();
  mul_compat.AddAttr("x_num_col_dims").IsNumEQ(2).End();
  mul_compat.AddAttr("y_num_col_dims").IsNumEQ(1).End();

  // elementwise_add is used both as the fc bias add and as the Q*K bias add:
  //   fc bias : X (B, S, N*H), Y (N*H),        Out (B, S, N*H), axis == 2
  //   bias qk : X (B, H, S, S), Y (B, H, S, S), Out (B, H, S, S),
  //             axis == -1 or 0
  auto& eltadd_compat = AddOpCompat(OpCompat("elementwise_add"));
  eltadd_compat.AddInput("X").IsTensor().End();
  eltadd_compat.AddInput("Y").IsTensor().End();
  eltadd_compat.AddOutput("Out").IsTensor().End();
  eltadd_compat.AddAttr("axis").IsIntIn({2, -1, 0}).End();

  // reshape2: forward (B, S, N*H) -> (B, S, H, N), backward the reverse.
  auto& reshape_compat = AddOpCompat(OpCompat("reshape2"));
  reshape_compat.AddInput("X").IsTensor().End();
  reshape_compat.AddInput("Shape").IsTensor().IsOptional().End();
  reshape_compat.AddInput("ShapeTensor").IsTensor().IsOptional().End();
  reshape_compat.AddOutput("Out").IsTensor().End();
  reshape_compat.AddOutput("XShape").IsOptional().IsTensor().End();
  reshape_compat.AddAttr("shape").IsType<std::vector<int>>().End();

  // transpose2: forward (B, S, H, N) -> (B, H, S, N), backward the reverse;
  // the expected axis is {0, 2, 1, 3}, but only its type is checked here.
  auto& transpose_compat = AddOpCompat(OpCompat("transpose2"));
  transpose_compat.AddInput("X").IsTensor().End();
  transpose_compat.AddOutput("Out").IsTensor().End();
  transpose_compat.AddOutput("XShape").IsOptional().IsTensor().End();
  transpose_compat.AddAttr("axis").IsType<std::vector<int>>().End();

  // scale: the "scale" value is copied to the fused op, so it is left
  // unconstrained; "bias" must be 0, which makes "bias_after_scale"
  // irrelevant.
  auto& scale_compat = AddOpCompat(OpCompat("scale"));
  scale_compat.AddInput("X").IsTensor().End();
  scale_compat.AddOutput("Out").IsTensor().End();
  scale_compat.AddAttr("scale").IsType<float>().End();
  scale_compat.AddAttr("bias").IsNumEQ(0.f).End();
  scale_compat.AddAttr("bias_after_scale").IsType<bool>().End();

  // matmul:
  //   QK  (B, H, S, N) * (B, H, S, N) -> (B, H, S, S), transpose_Y == true
  //   QKV (B, H, S, S) * (B, H, S, N) -> (B, H, S, N), transpose_Y == false
  auto& matmul_compat = AddOpCompat(OpCompat("matmul"));
  matmul_compat.AddInput("X").IsTensor().End();
  matmul_compat.AddInput("Y").IsTensor().End();
  matmul_compat.AddOutput("Out").IsTensor().End();
  matmul_compat.AddAttr("alpha").IsNumEQ(1.0f).End();
  matmul_compat.AddAttr("transpose_X").IsBoolEQ(false).End();
  matmul_compat.AddAttr("transpose_Y").IsType<bool>().End();

  // softmax: input shape is (B, H, S, S), so axis is -1 or 3.
  auto& softmax_compat = AddOpCompat(OpCompat("softmax"));
  softmax_compat.AddInput("X").IsTensor().End();
  softmax_compat.AddOutput("Out").IsTensor().End();
  softmax_compat.AddAttr("axis").IsIntIn({-1, 3}).End();
}
// Finds every subgraph matching patterns::TrtMultiHeadMatmulPattern and
// replaces it with a single "multihead_matmul" op.
//
// For each match the Q/K/V weights are interleaved into one combined tensor
// stored in the Q weight variable, and the three biases are concatenated into
// the Q bias variable; the K/V weight and bias variables are erased from
// `scope`. Matches whose weights or biases have more than one consumer are
// left untouched.
//
// graph:      graph to rewrite in place.
// name_scope: name scope handed to the pattern.
// scope:      parameter scope holding the weight/bias tensors; it is
//             dereferenced without a null check.
// Returns the number of subgraphs that were fused.
int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
                                                const std::string& name_scope,
                                                Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Create pattern.
  patterns::TrtMultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
  multihead_pattern();

  // Create New OpDesc
  auto fuse_creater = [&](
      Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
      Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
      Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
      Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out,
      Node* softmax_qk, Node* eltadd0, Node* eltadd1, Node* eltadd2,
      Node* matmul_qk, Node* reshape2_qkv) {
    // The scale op's factor becomes the "alpha" attribute of the fused op.
    auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));

    // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
    // bias (B * S * 3 * N * H) + bias (3 * N * H)
    // Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H)
    auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable<LoDTensor>();
    auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable<LoDTensor>();
    auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable<LoDTensor>();

    auto* bq_tensor =
        scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
    auto* bk_tensor =
        scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
    auto* bv_tensor =
        scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();

    auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
    auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
    auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());

    auto combined_w_dims =
        phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
    auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});

    // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
    auto* combined_w_desc = mul0_w->Var();
    combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
    combined_w_desc->SetPersistable(true);

    auto* combined_bias_desc = eltadd0_b->Var();
    combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
    combined_bias_desc->SetPersistable(true);

    // Build the combined weight in a scratch tensor first, because the Q
    // weight tensor is both a source and the final destination.
    framework::LoDTensor tmp_combined_w_tensor;
    tmp_combined_w_tensor.Resize(combined_w_dims);
    auto* tmp_combined_w_data =
        tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());

    std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
    int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
    // Combine the three fc weights together.
    for (int i = 0; i < dims_h; i++) {
      for (int j = 0; j < 3; j++) {
        for (int k = 0; k < dims_w; k++) {
          int out_index = i * (3 * dims_w) + j * dims_w + k;
          int in_index = i * dims_w + k;
          tmp_combined_w_data[out_index] = w_vec[j][in_index];
        }
      }
    }

    wq_tensor->Resize(combined_w_dims);
    auto* new_combined_w_data =
        wq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_w_data, tmp_combined_w_data,
           sizeof(float) * wq_tensor->numel());

    scope->EraseVars({mul1_w->Name(), mul2_w->Name()});

    // Concatenate the three biases as [bq | bk | bv], again via a scratch
    // tensor because the Q bias tensor is reused as the destination.
    framework::LoDTensor tmp_combined_bias_tensor;
    tmp_combined_bias_tensor.Resize(combined_bias_dims);
    auto* tmp_combined_bias_data =
        tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());

    size_t bias_size = bq_tensor->numel();
    memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + bias_size, bk_data,
           sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
           sizeof(float) * bias_size);

    bq_tensor->Resize(combined_bias_dims);
    auto* new_combined_bias_data =
        bq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_bias_data, tmp_combined_bias_data,
           sizeof(float) * bq_tensor->numel());

    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

    // The head count is read from index 2 of the Q-branch reshape2 "shape".
    auto reshape_desc = reshape2->Op();
    int head_number =
        BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")).at(2);

    OpDesc multihead_op_desc(mul0->Op()->Block());
    multihead_op_desc.SetType("multihead_matmul");

    multihead_op_desc.SetInput("Input", {input0->Name()});
    multihead_op_desc.SetInput("W", {mul0_w->Name()});
    multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
    multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});

    multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
    multihead_op_desc.SetAttr("alpha", scale_attr);
    multihead_op_desc.SetAttr("head_number", head_number);

    auto* mul0_op_desc = mul0->Op();

    // all mul op has same input.
    // The attribute has to be looked up on the source mul op: the fused desc
    // was created just above and cannot carry "Input_scale" yet.
    if (mul0_op_desc->HasAttr("Input_scale")) {
      multihead_op_desc.SetAttr("Input_scale",
                                mul0_op_desc->GetAttr("Input_scale"));
    }

    auto* add0_op_desc = eltadd0->Op();
    auto* add1_op_desc = eltadd1->Op();
    auto* add2_op_desc = eltadd2->Op();
    // Only eltadd0 is probed; eltadd1/eltadd2 are read unconditionally, so
    // they are assumed to carry "out_threshold" whenever eltadd0 does.
    if (add0_op_desc->HasAttr("out_threshold")) {
      auto out_scale0 =
          BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
      auto out_scale1 =
          BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
      auto out_scale2 =
          BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
      auto out_scale_max = std::max(out_scale0, out_scale1);
      out_scale_max = std::max(out_scale_max, out_scale2);
      multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
    }

    auto* softmax_qk_op_desc = softmax_qk->Op();
    auto* matmul_qk_op_desc = matmul_qk->Op();
    if (matmul_qk_op_desc->HasAttr("Input_scale")) {
      multihead_op_desc.SetAttr("qkv2context_plugin_int8", true);
      if (softmax_qk_op_desc->HasAttr("out_threshold")) {
        auto qkv_plugin_scale = BOOST_GET_CONST(
            float, softmax_qk_op_desc->GetAttr("out_threshold"));
        multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale);
      }
    }
    if (reshape2_qkv->Op()->HasAttr("out_threshold")) {
      multihead_op_desc.SetAttr("out_threshold",
                                reshape2_qkv->Op()->GetAttr("out_threshold"));
    }

    auto* multihead = graph->CreateOpNode(&multihead_op_desc);

    IR_NODE_LINK_TO(input0, multihead);
    IR_NODE_LINK_TO(mul0_w, multihead);
    IR_NODE_LINK_TO(eltadd0_b, multihead);
    IR_NODE_LINK_TO(eltadd_qk_b, multihead);

    IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
  };

  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING)
          << "Op compat check in trt_multihead_matmul_fuse_pass_v2 failed.";
      return;
    }
    // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
                              multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
                              multihead_pattern);

    // nodes need be removed
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
                              multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
                              multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                              multihead_pattern);

    // If weights or biases in qkv's fc are shared by multiple multihead_matmul
    // patterns, we do not support this kind of fusion, this pass will not take
    // effect.
    bool is_fc_params_shared =
        mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
        mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
        eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
    if (is_fc_params_shared) {
      return;
    }
    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
                 mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
                 reshape2_0, reshape2_qkv_out, scale, scale_out, softmax_qk,
                 eltadd0, eltadd1, eltadd2, matmul_qk, reshape2_qkv);

    // mul0_w, eltadd0_b, eltadd_qk_b, input0 and reshape2_qkv_out stay in the
    // graph because the fused op is linked to them.
    std::unordered_set<const Node*> marked_nodes({eltadd0,
                                                  eltadd1,
                                                  eltadd2,
                                                  eltadd1_b,
                                                  eltadd2_b,
                                                  eltadd0_out,
                                                  eltadd1_out,
                                                  eltadd2_out,
                                                  reshape2_0,
                                                  reshape2_1,
                                                  reshape2_2,
                                                  reshape2_0_out,
                                                  reshape2_1_out,
                                                  reshape2_2_out,
                                                  transpose2_0,
                                                  transpose2_1,
                                                  transpose2_2,
                                                  transpose2_0_out,
                                                  transpose2_1_out,
                                                  transpose2_2_out,
                                                  matmul_qk,
                                                  matmul_qk_out,
                                                  eltadd_qk,
                                                  eltadd_qk_out,
                                                  softmax_qk,
                                                  softmax_qk_out,
                                                  transpose2_qkv,
                                                  transpose2_qkv_out,
                                                  matmul_qkv,
                                                  matmul_qkv_out,
                                                  mul0,
                                                  mul1,
                                                  mul2,
                                                  mul0_out,
                                                  mul1_out,
                                                  mul2_out,
                                                  mul1_w,
                                                  mul2_w,
                                                  reshape2_qkv,
                                                  scale});
    // Remove unneeded nodes.
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };
  gpd(graph, handler);

  return fusion_count;
}
// Runs the V2 multi-head fusion on `graph`.
//
// When at least one subgraph was fused, the varseqlen configuration is
// validated: either use_varseqlen is on with both pos_id and mask_id set (and
// the graph must then carry kEmbEltwiseLayernormPass), or use_varseqlen is off
// with both ids empty. Any other combination throws. On success the graph is
// tagged with kMultiheadMatmulPass. The fusion count is always reported via
// AddStatis.
void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal(
          "During the multiheadMatmul pass, The scope should not be null."));

  const int fusion_count = BuildFusionV2(graph, name_scope_, scope);
  if (fusion_count > 0) {
    const bool use_varseqlen = Get<bool>("use_varseqlen");
    const std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    const std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    const bool has_pos_id = !pos_id.empty();
    const bool has_mask_id = !mask_id.empty();
    const bool varseqlen_mode = use_varseqlen && has_pos_id && has_mask_id;
    const bool plain_mode = !use_varseqlen && !has_pos_id && !has_mask_id;

    // Reject half-configured setups first.
    if (!varseqlen_mode && !plain_mode) {
      PADDLE_THROW(
          platform::errors::Fatal("Use transformer'varseqlen need config: "
                                  "use_varseqlen, set pos_id, set "
                                  "mask_id. Or not use varseqlen, do not set "
                                  "pos_id, set mask_id. Please "
                                  "reconfig"));
    }

    if (varseqlen_mode) {
      // varseqlen requires the embedding fuse pass to have marked the graph.
      if (!graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
        PADDLE_THROW(platform::errors::Fatal(
            "Use transformer'varseqlen need "
            "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
      }
      VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v2";
    } else {
      VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v2";
    }

    graph->Set(kMultiheadMatmulPass, new bool(true));
  }
  AddStatis(fusion_count);
}
// Registers the op-compatibility constraints for every op type this pass
// declares: mul, elementwise_add, reshape2, transpose2, matmul, matmul_v2 and
// softmax.
TrtMultiHeadMatmulV3FusePass::TrtMultiHeadMatmulV3FusePass() {
  // mul: X should be (B, S, N*H), Y should be (N*H, N*H),
  // Out should be (B, S, N*H).
  auto& mul_compat = AddOpCompat(OpCompat("mul"));
  mul_compat.AddInput("X").IsTensor().End();
  mul_compat.AddInput("Y").IsTensor().End();
  mul_compat.AddOutput("Out").IsTensor().End();
  mul_compat.AddAttr("x_num_col_dims").IsNumEQ(2).End();
  mul_compat.AddAttr("y_num_col_dims").IsNumEQ(1).End();

  // elementwise_add is used both as the fc bias add and as the Q*K bias add:
  //   fc bias : X (B, S, N*H), Y (N*H),        Out (B, S, N*H), axis == 2
  //   bias qk : X (B, H, S, S), Y (B, H, S, S), Out (B, H, S, S),
  //             axis == -1 or 0
  auto& eltadd_compat = AddOpCompat(OpCompat("elementwise_add"));
  eltadd_compat.AddInput("X").IsTensor().End();
  eltadd_compat.AddInput("Y").IsTensor().End();
  eltadd_compat.AddOutput("Out").IsTensor().End();
  eltadd_compat.AddAttr("axis").IsIntIn({2, -1, 0}).End();

  // reshape2: forward (B, S, N*H) -> (B, S, H, N), backward the reverse.
  auto& reshape_compat = AddOpCompat(OpCompat("reshape2"));
  reshape_compat.AddInput("X").IsTensor().End();
  reshape_compat.AddInput("Shape").IsTensor().IsOptional().End();
  reshape_compat.AddInput("ShapeTensor").IsTensor().IsOptional().End();
  reshape_compat.AddOutput("Out").IsTensor().End();
  reshape_compat.AddOutput("XShape").IsOptional().IsTensor().End();
  reshape_compat.AddAttr("shape").IsType<std::vector<int>>().End();

  // transpose2: forward (B, S, H, N) -> (B, H, S, N), backward the reverse;
  // the expected axis is {0, 2, 1, 3}, but only its type is checked here.
  auto& transpose_compat = AddOpCompat(OpCompat("transpose2"));
  transpose_compat.AddInput("X").IsTensor().End();
  transpose_compat.AddOutput("Out").IsTensor().End();
  transpose_compat.AddOutput("XShape").IsOptional().IsTensor().End();
  transpose_compat.AddAttr("axis").IsType<std::vector<int>>().End();

  // matmul:
  //   QK  (B, H, S, N) * (B, H, S, N) -> (B, H, S, S), transpose_Y == true,
  //       alpha may be any value (it is copied to the fused op)
  //   QKV (B, H, S, S) * (B, H, S, N) -> (B, H, S, N), transpose_Y == false,
  //       alpha is expected to be 1.0
  auto& matmul_compat = AddOpCompat(OpCompat("matmul"));
  matmul_compat.AddInput("X").IsTensor().End();
  matmul_compat.AddInput("Y").IsTensor().End();
  matmul_compat.AddOutput("Out").IsTensor().End();
  matmul_compat.AddAttr("alpha").IsType<float>().End();
  matmul_compat.AddAttr("transpose_X").IsBoolEQ(false).End();
  matmul_compat.AddAttr("transpose_Y").IsType<bool>().End();

  // matmul_v2: trans_y is true for QK and false for QKV.
  auto& matmul_v2_compat = AddOpCompat(OpCompat("matmul_v2"));
  matmul_v2_compat.AddInput("X").IsTensor().End();
  matmul_v2_compat.AddInput("Y").IsTensor().End();
  matmul_v2_compat.AddOutput("Out").IsTensor().End();
  matmul_v2_compat.AddAttr("trans_x").IsBoolEQ(false).End();
  matmul_v2_compat.AddAttr("trans_y").IsType<bool>().End();

  // softmax: input shape is (B, H, S, S), so axis is -1 or 3.
  auto& softmax_compat = AddOpCompat(OpCompat("softmax"));
  softmax_compat.AddInput("X").IsTensor().End();
  softmax_compat.AddOutput("Out").IsTensor().End();
  softmax_compat.AddAttr("axis").IsIntIn({-1, 3}).End();
}
// Finds every subgraph matching patterns::TrtMultiHeadMatmulV3Pattern and
// replaces it with a single "multihead_matmul" op.
//
// For each match the Q/K/V weights are interleaved into one combined tensor
// stored in the Q weight variable, and the three biases are concatenated into
// the Q bias variable; the K/V weight and bias variables are erased from
// `scope`. Matches whose weights or biases have more than one consumer are
// left untouched.
//
// graph:      graph to rewrite in place.
// name_scope: name scope handed to the pattern.
// scope:      parameter scope holding the weight/bias tensors; it is
//             dereferenced without a null check.
// Returns the number of subgraphs that were fused.
int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
                                                const std::string& name_scope,
                                                Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Create pattern.
  patterns::TrtMultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
  multihead_pattern();

  // Create New OpDesc
  auto fuse_creater = [&](
      Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
      Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
      Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
      Node* reshape2, Node* reshape2_qkv_out, Node* matmul_qk) {
    // The Q*K scaling factor comes from the matmul "alpha" attribute. The
    // pattern also accepts matmul_v2 for this node, and the matmul_v2 compat
    // declared by this pass lists no "alpha"; reading it unconditionally
    // would fail for such matches, so fall back to 1.0 (no scaling).
    auto* matmul_qk_desc = matmul_qk->Op();
    float scale_attr =
        matmul_qk_desc->HasAttr("alpha")
            ? BOOST_GET_CONST(float, matmul_qk_desc->GetAttr("alpha"))
            : 1.0f;

    // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
    // bias (B * S * 3 * N * H) + bias (3 * N * H)
    // Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H)
    auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable<LoDTensor>();
    auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable<LoDTensor>();
    auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable<LoDTensor>();

    auto* bq_tensor =
        scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
    auto* bk_tensor =
        scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
    auto* bv_tensor =
        scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();

    auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
    auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
    auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
    auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());

    auto combined_w_dims =
        phi::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
    auto combined_bias_dims = phi::make_ddim({3, bq_tensor->dims()[0]});

    // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
    auto* combined_w_desc = mul0_w->Var();
    combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
    combined_w_desc->SetPersistable(true);

    auto* combined_bias_desc = eltadd0_b->Var();
    combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
    combined_bias_desc->SetPersistable(true);

    // Build the combined weight in a scratch tensor first, because the Q
    // weight tensor is both a source and the final destination.
    framework::LoDTensor tmp_combined_w_tensor;
    tmp_combined_w_tensor.Resize(combined_w_dims);
    auto* tmp_combined_w_data =
        tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());

    std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
    int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
    // Combine the three fc weights together.
    for (int i = 0; i < dims_h; i++) {
      for (int j = 0; j < 3; j++) {
        for (int k = 0; k < dims_w; k++) {
          int out_index = i * (3 * dims_w) + j * dims_w + k;
          int in_index = i * dims_w + k;
          tmp_combined_w_data[out_index] = w_vec[j][in_index];
        }
      }
    }

    wq_tensor->Resize(combined_w_dims);
    auto* new_combined_w_data =
        wq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_w_data, tmp_combined_w_data,
           sizeof(float) * wq_tensor->numel());

    scope->EraseVars({mul1_w->Name(), mul2_w->Name()});

    // Concatenate the three biases as [bq | bk | bv], again via a scratch
    // tensor because the Q bias tensor is reused as the destination.
    framework::LoDTensor tmp_combined_bias_tensor;
    tmp_combined_bias_tensor.Resize(combined_bias_dims);
    auto* tmp_combined_bias_data =
        tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());

    size_t bias_size = bq_tensor->numel();
    memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + bias_size, bk_data,
           sizeof(float) * bias_size);
    memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
           sizeof(float) * bias_size);

    bq_tensor->Resize(combined_bias_dims);
    auto* new_combined_bias_data =
        bq_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_combined_bias_data, tmp_combined_bias_data,
           sizeof(float) * bq_tensor->numel());

    scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});

    // The head count is read from index 2 of the Q-branch reshape2 "shape".
    auto reshape_desc = reshape2->Op();
    int head_number =
        BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")).at(2);

    OpDesc multihead_op_desc(mul0->Op()->Block());
    multihead_op_desc.SetType("multihead_matmul");

    multihead_op_desc.SetInput("Input", {input0->Name()});
    multihead_op_desc.SetInput("W", {mul0_w->Name()});
    multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
    multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});

    multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
    multihead_op_desc.SetAttr("alpha", scale_attr);
    multihead_op_desc.SetAttr("head_number", head_number);

    auto* multihead = graph->CreateOpNode(&multihead_op_desc);
    IR_NODE_LINK_TO(input0, multihead);
    IR_NODE_LINK_TO(mul0_w, multihead);
    IR_NODE_LINK_TO(eltadd0_b, multihead);
    IR_NODE_LINK_TO(eltadd_qk_b, multihead);

    IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
  };

  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
                              multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
                              multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
                              multihead_pattern);

    // nodes need be removed
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
                              multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
                              multihead_pattern);

    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
                              multihead_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
                              multihead_pattern);

    // If weights or biases in qkv's fc are shared by multiple multihead_matmul
    // patterns, we do not support this kind of fusion, this pass will not take
    // effect.
    bool is_fc_params_shared =
        mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
        mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
        eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
    if (is_fc_params_shared) {
      return;
    }
    fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
                 mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
                 reshape2_0, reshape2_qkv_out, matmul_qk);

    // mul0_w, eltadd0_b, eltadd_qk_b, input0 and reshape2_qkv_out stay in the
    // graph because the fused op is linked to them.
    std::unordered_set<const Node*> marked_nodes({eltadd0,
                                                  eltadd1,
                                                  eltadd2,
                                                  eltadd1_b,
                                                  eltadd2_b,
                                                  eltadd0_out,
                                                  eltadd1_out,
                                                  eltadd2_out,
                                                  reshape2_0,
                                                  reshape2_1,
                                                  reshape2_2,
                                                  reshape2_0_out,
                                                  reshape2_1_out,
                                                  reshape2_2_out,
                                                  transpose2_0,
                                                  transpose2_1,
                                                  transpose2_2,
                                                  transpose2_0_out,
                                                  transpose2_1_out,
                                                  transpose2_2_out,
                                                  matmul_qk,
                                                  matmul_qk_out,
                                                  eltadd_qk,
                                                  eltadd_qk_out,
                                                  softmax_qk,
                                                  softmax_qk_out,
                                                  transpose2_qkv,
                                                  transpose2_qkv_out,
                                                  matmul_qkv,
                                                  matmul_qkv_out,
                                                  mul0,
                                                  mul1,
                                                  mul2,
                                                  mul0_out,
                                                  mul1_out,
                                                  mul2_out,
                                                  mul1_w,
                                                  mul2_w,
                                                  reshape2_qkv});
    // Remove unneeded nodes.
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };
  gpd(graph, handler);

  return fusion_count;
}
// Applies the V3 multi-head matmul fusion to `graph` and, when at least one
// subgraph was fused, validates the transformer configuration:
//   * varseqlen mode:    use_varseqlen is true and both pos_id and mask_id are
//                        non-empty; the embedding-eltwise-layernorm marker
//                        must already be present on the graph.
//   * no-varseqlen mode: use_varseqlen is false and both ids are empty.
// Any other combination is reported as a fatal configuration error.
// On success the graph is tagged with kMultiheadMatmulPass. The fusion count
// is always reported through AddStatis.
void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal(
          "During the multiheadMatmul pass, The scope should not be null."));

  int fusion_count = BuildFusionV3(graph, name_scope_, scope);
  if (fusion_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    if (use_varseqlen && pos_id != "" && mask_id != "") {
      // varseqlen relies on the embedding fusion having run before this pass.
      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
        VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v3";
      } else {
        PADDLE_THROW(platform::errors::Fatal(
            "Use transformer'varseqlen need "
            "embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
      }
    } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
      VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v3";
    } else {
      // The message now matches the check above: without varseqlen, neither
      // pos_id nor mask_id may be set.
      PADDLE_THROW(
          platform::errors::Fatal("Use transformer'varseqlen need config: "
                                  "use_varseqlen, set pos_id, set "
                                  "mask_id. Or not use varseqlen, do not set "
                                  "pos_id and mask_id. Please "
                                  "reconfig"));
    }

    // Another multihead fuse pass may already have published this marker on
    // the same graph; only set it once.
    // NOTE(review): Graph::Set is expected to reject an attribute that already
    // exists - confirm against ir/graph.h.
    if (!graph->Has(kMultiheadMatmulPass)) {
      graph->Set(kMultiheadMatmulPass, new bool(true));
    }
  }
  AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Register the three multi-head matmul fuse passes implemented in this file
// under their pass names.
REGISTER_PASS(trt_multihead_matmul_fuse_pass,
paddle::framework::ir::TrtMultiHeadMatmulFusePass);
REGISTER_PASS(trt_multihead_matmul_fuse_pass_v2,
paddle::framework::ir::TrtMultiHeadMatmulV2FusePass);
REGISTER_PASS(trt_multihead_matmul_fuse_pass_v3,
paddle::framework::ir::TrtMultiHeadMatmulV3FusePass);
// Operator-version combination declared for the V2 pass.
// Note: no capability is registered here for the unversioned
// trt_multihead_matmul_fuse_pass.
REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v2)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("softmax", 0));
// Operator-version combination declared for the V3 pass. Unlike V2 it lists
// "matmul_v2" and does not list "mul".
REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v3)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern descriptor registered under the name "multihead_matmul".
// It only declares the named nodes of the pattern; the pattern itself is
// built by operator()(), which is defined elsewhere.
// Naming convention used below: `<op>` is an operator node, `<op>_w` / `<op>_b`
// its weight / bias variable and `<op>_out` its output variable.
struct TrtMultiHeadMatmulPattern : public PatternBase {
TrtMultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul") {}
// Builds the pattern and returns one of its nodes (definition not in this
// header).
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD bias
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD bias
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD bias
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
// scale / scale_out are declared only in this pattern, not in the V3 one.
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
// Pattern descriptor registered under the name "multihead_matmul_v3".
// It declares the same node names as TrtMultiHeadMatmulPattern except that
// there is no scale / scale_out node. The pattern itself is built by
// operator()(), which is defined elsewhere.
struct TrtMultiHeadMatmulV3Pattern : public PatternBase {
TrtMultiHeadMatmulV3Pattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul_v3") {}
// Builds the pattern and returns one of its nodes (definition not in this
// header).
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD bias
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD bias
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD bias
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
class TrtMultiHeadMatmulFusePass : public FusePassBase {
public:
virtual ~TrtMultiHeadMatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"trt_multihead_matmul_fuse"};
};
class TrtMultiHeadMatmulV2FusePass : public FusePassBase {
public:
TrtMultiHeadMatmulV2FusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"trt_multihead_matmul_fuse_v2"};
private:
int BuildFusionV2(Graph* graph, const std::string& name_scope,
Scope* scope) const;
};
class TrtMultiHeadMatmulV3FusePass : public FusePassBase {
public:
TrtMultiHeadMatmulV3FusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"trt_multihead_matmul_fuse_v3"};
private:
int BuildFusionV3(Graph* graph, const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
// Forward declaration only; this file uses Node through pointers.
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern descriptor (registered as "skip_layernorm") for an elementwise_add
// whose output is consumed by a layer_norm. The two inputs of the add are
// supplied by the caller of operator().
struct TrtSkipLayerNorm : public PatternBase {
TrtSkipLayerNorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "skip_layernorm") {}
// Builds the pattern on top of inputs x and y; returns the layer_norm "Y"
// output node.
PDNode *operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
PATTERN_DECL_NODE(
elementwise_out); // (elementwise_input_x,elementwise_input_y) ->
// elementwise_out
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
};
// Extends the pattern with an `elementwise_add` whose result feeds a
// `layer_norm`, using the caller-provided nodes as the add's "X" and "Y"
// inputs. Returns the pattern node bound to layer_norm's "Y" output.
PDNode *TrtSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
  // The two incoming variables are the X / Y operands of the add.
  x->assert_is_op_input("elementwise_add", "X");
  y->assert_is_op_input("elementwise_add", "Y");

  // The add operator and the single variable it produces.
  auto *add_op =
      pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
  auto *add_out = pattern->NewNode(elementwise_out_repr())
                      ->AsOutput()
                      ->assert_is_only_output_of_op("elementwise_add");
  add_op->LinksFrom({x, y}).LinksTo({add_out});

  // The add result is consumed by layer_norm, so it is re-tagged as an
  // intermediate node of the pattern.
  add_out->AsIntermediate()->assert_is_op_input("layer_norm");
  auto *ln_op = pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");

  // Persistable parameters of layer_norm.
  auto *ln_bias = pattern->NewNode(layer_norm_bias_repr())
                      ->AsInput()
                      ->assert_is_persistable_var()
                      ->assert_is_op_input("layer_norm", "Bias");
  auto *ln_scale = pattern->NewNode(layer_norm_scale_repr())
                       ->AsInput()
                       ->assert_is_persistable_var()
                       ->assert_is_op_input("layer_norm", "Scale");

  // The three outputs of layer_norm.
  auto *ln_out = pattern->NewNode(layer_norm_out_repr())
                     ->AsOutput()
                     ->assert_is_op_output("layer_norm", "Y");
  auto *ln_mean = pattern->NewNode(layer_norm_mean_repr())
                      ->AsOutput()
                      ->assert_is_op_output("layer_norm", "Mean");
  auto *ln_variance = pattern->NewNode(layer_norm_variance_repr())
                          ->AsOutput()
                          ->assert_is_op_output("layer_norm", "Variance");

  // Wire layer_norm to its inputs and outputs.
  ln_op->LinksFrom({add_out, ln_bias, ln_scale})
      .LinksTo({ln_out, ln_mean, ln_variance});
  return ln_out;
}
} // namespace patterns
// Replaces every `elementwise_add -> layer_norm` subgraph (with two
// non-persistable add inputs) by a single `skip_layernorm` op, then validates
// the transformer configuration if anything was fused:
//   * varseqlen mode:    use_varseqlen is true and both pos_id and mask_id are
//                        non-empty; the graph must already carry both the
//                        embedding-eltwise-layernorm and the multihead-matmul
//                        markers.
//   * no-varseqlen mode: use_varseqlen is false and both ids are empty.
// Any other combination is reported as a fatal configuration error.
// The number of fused subgraphs is always reported through AddStatis.
void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("skip_layernorm_fuse", graph);
  int found_subgraph_count = 0;

  GraphPatternDetector gpd;
  auto *x = gpd.mutable_pattern()
                ->NewNode("skip_layernorm_fuse/x")
                ->AsInput()
                ->assert_is_op_input("elementwise_add", "X")
                ->assert_var_not_persistable();
  auto *y = gpd.mutable_pattern()
                ->NewNode("skip_layernorm_fuse/y")
                ->AsInput()
                ->assert_is_op_input("elementwise_add", "Y")
                ->assert_var_not_persistable();
  patterns::TrtSkipLayerNorm fused_pattern(gpd.mutable_pattern(),
                                           "skip_layernorm_fuse");
  fused_pattern(x, y);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "skip_layernorm pass in op compat failed.";
      return;
    }
    VLOG(4) << "handle TrtSkipLayerNorm fuse";
    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                              fused_pattern);

    std::unordered_set<const Node *> del_node_set;

    // Create an TrtSkipLayerNorm op node
    OpDesc new_desc(elementwise->Op()->Block());
    new_desc.SetType("skip_layernorm");

    // inputs
    new_desc.SetInput("X", {subgraph.at(x)->Name()});
    new_desc.SetInput("Y", {subgraph.at(y)->Name()});
    new_desc.SetInput("Scale", {layer_norm_scale->Name()});
    new_desc.SetInput("Bias", {layer_norm_bias->Name()});

    // A layer_norm carrying "out_threshold" switches the fused op to int8 and
    // forwards the threshold.
    if (layer_norm->Op()->HasAttr("out_threshold")) {
      new_desc.SetAttr("enable_int8", true);
      new_desc.SetAttr("out_threshold",
                       layer_norm->Op()->GetAttr("out_threshold"));
    }

    // outputs
    new_desc.SetOutput("Out", {layer_norm_out->Name()});

    // attrs
    new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
    new_desc.SetAttr("begin_norm_axis",
                     layer_norm->Op()->GetAttr("begin_norm_axis"));

    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.

    // Drop the two original ops, the intermediate add output and the
    // Mean / Variance outputs, which the fused op does not produce.
    del_node_set.insert(elementwise);
    del_node_set.insert(layer_norm);
    del_node_set.insert(elementwise_out);
    del_node_set.insert(layer_norm_mean);
    del_node_set.insert(layer_norm_variance);
    GraphSafeRemoveNodes(graph, del_node_set);

    // Re-wire the surviving variables to the fused op.
    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
    IR_NODE_LINK_TO(subgraph.at(y), fused_node);
    IR_NODE_LINK_TO(layer_norm_scale, fused_node);
    IR_NODE_LINK_TO(layer_norm_bias, fused_node);
    IR_NODE_LINK_TO(fused_node, layer_norm_out);
    found_subgraph_count++;
  };

  gpd(graph, handler);

  if (found_subgraph_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    if (use_varseqlen && pos_id != "" && mask_id != "") {
      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
          graph->Has(framework::ir::kMultiheadMatmulPass)) {
        VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass";
      } else {
        // Both markers are required by the check above, so name both passes.
        PADDLE_THROW(platform::errors::Fatal(
            "Use transformer'varseqlen need "
            "embedding_eltwise_layernorm_fuse_pass and "
            "multihead_matmul_fuse_pass. please use no_varseqlen"));
      }
    } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
      VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
    } else {
      // The message now matches the check above: without varseqlen, neither
      // pos_id nor mask_id may be set.
      PADDLE_THROW(
          platform::errors::Fatal("Use transformer'varseqlen need config: "
                                  "use_varseqlen, set pos_id, set "
                                  "mask_id. Or not use varseqlen, do not set "
                                  "pos_id and mask_id. Please "
                                  "reconfig"));
    }
  }
  AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Register the pass under the name `trt_skip_layernorm_fuse_pass`.
REGISTER_PASS(trt_skip_layernorm_fuse_pass,
paddle::framework::ir::TrtSkipLayerNormFusePass);
// Operator-version combination declared for this pass: elementwise_add <= 1
// and layer_norm == 0.
REGISTER_PASS_CAPABILITY(trt_skip_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("layer_norm", 0));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
//     |           |                            |            |
// other_op1   other_op2                    other_op1    other_op2
//     |           |              fuse           \          /
//     |------elementwise_add      ->           skip_layernorm
//                 |                                   |
//             layer_norm                          other_op3
//                 |                                   |
//             other_op3
//                 |
class Graph;
// Fuse pass that turns elementwise_add + layer_norm into skip_layernorm.
// The constructor declares the operator signatures the pass accepts; they are
// checked through IsCompat before a subgraph is rewritten.
class TrtSkipLayerNormFusePass : public FusePassBase {
public:
TrtSkipLayerNormFusePass() {
// elementwise_add: tensor inputs X / Y, tensor output Out, axis in {0, -1}.
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();
// layer_norm: tensor inputs X / Scale / Bias, tensor outputs
// Y / Mean / Variance, epsilon within [0, 0.001], begin_norm_axis > 0.
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}
virtual ~TrtSkipLayerNormFusePass() {}
protected:
// Pass entry point; performs the fusion on `graph`.
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -216,8 +216,12 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
bool);
DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
DECL_ARGUMENT_FIELD(tensorrt_use_oss, TensorRtUseOSS, bool);
DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
DECL_ARGUMENT_FIELD(tensorrt_transformer_posid, TensorRtTransformerPosid,
std::string);
DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid, TensorRtTransformerMaskid,
std::string);
DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path,
TensorRtShapeRangeInfoPath, std::string);
DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape, TensorRtTunedDynamicShape,
......
......@@ -55,9 +55,13 @@ void IRPassManager::CreatePasses(Argument *argument,
int pass_num = 0;
for (const std::string &pass_name : passes) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
pass->Set("use_oss", new bool(argument->tensorrt_use_oss()));
pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
pass->Set("with_interleaved",
new bool(argument->tensorrt_with_interleaved()));
pass->Set("tensorrt_transformer_posid",
new std::string(argument->tensorrt_transformer_posid()));
pass->Set("tensorrt_transformer_maskid",
new std::string(argument->tensorrt_transformer_maskid()));
pass->Set("disable_logs", new bool(argument->disable_logs()));
auto precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
......
......@@ -377,12 +377,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
Get<int>("workspace_size"), precision_mode, calibrator.get(),
Get<int>("gpu_device_id"), min_input_shape, max_input_shape,
opt_input_shape, disable_trt_plugin_fp16);
trt_engine->SetUseOSS(Get<bool>("use_oss"));
trt_engine->SetUseOSS(Get<bool>("use_varseqlen"));
trt_engine->SetWithInterleaved(Get<bool>("with_interleaved"));
trt_engine->SetTransformerPosid(
Get<std::string>("tensorrt_transformer_posid"));
trt_engine->SetTransformerMaskid(
Get<std::string>("tensorrt_transformer_maskid"));
trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
trt_engine->SetDLACore(Get<int>("trt_dla_core"));
trt_engine->SetUseInspector(Get<bool>("use_inspector"));
trt_engine->SetWithErnie(graph->Has(framework::ir::kMultiheadMatmulPass));
trt_engine->SetWithErnie(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass));
if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData(
......
......@@ -256,8 +256,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(trt_dla_core_);
CP_MEMBER(trt_use_static_engine_);
CP_MEMBER(trt_use_calib_mode_);
CP_MEMBER(trt_use_oss_);
CP_MEMBER(trt_use_varseqlen_);
CP_MEMBER(trt_with_interleaved_);
CP_MEMBER(tensorrt_transformer_posid_);
CP_MEMBER(tensorrt_transformer_maskid_);
CP_MEMBER(trt_tuned_dynamic_shape_);
CP_MEMBER(trt_allow_build_at_runtime_);
CP_MEMBER(collect_shape_range_info_);
......@@ -546,7 +548,7 @@ void AnalysisConfig::Exp_DisableTensorRtOPs(
trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}
void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }
// Sets trt_use_varseqlen_, the flag that is forwarded to the IR passes as
// "use_varseqlen".
void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; }
// TODO(Superjomn) refactor this, buggy.
void AnalysisConfig::Update() {
......@@ -1034,9 +1036,13 @@ std::string AnalysisConfig::Summary() {
? shape_range_info_path_
: "false"});
os.InsertRow({"tensorrt_use_oss", trt_use_oss_ ? "true" : "false"});
os.InsertRow(
{"tensorrt_use_varseqlen", trt_use_varseqlen_ ? "true" : "false"});
os.InsertRow({"tensorrt_with_interleaved",
trt_with_interleaved_ ? "true" : "false"});
os.InsertRow({"tensorrt_transformer_posid", tensorrt_transformer_posid_});
os.InsertRow(
{"tensorrt_transformer_maskid", tensorrt_transformer_maskid_});
os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
if (trt_use_dla_) {
os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
......
......@@ -853,8 +853,10 @@ void AnalysisPredictor::PrepareArgument() {
}
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_oss_);
argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_);
argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_.SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_);
argument_.SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_);
argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_);
......@@ -1803,6 +1805,9 @@ USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
USE_TRT_CONVERTER(preln_skip_layernorm)
USE_TRT_CONVERTER(roll)
USE_TRT_CONVERTER(strided_slice)
USE_TRT_CONVERTER(transformer_input_convert)
USE_TRT_CONVERTER(recover_padding)
USE_TRT_CONVERTER(remove_padding)
#endif
namespace paddle_infer {
......@@ -1971,6 +1976,20 @@ void InternalUtils::UpdateConfigInterleaved(paddle_infer::Config *c,
#endif
}
// Records the transformer pos-id tensor name on the config.
// Only takes effect when built with PADDLE_WITH_CUDA; otherwise the call does
// nothing.
void InternalUtils::SetTransformerPosid(
    paddle_infer::Config *c, const std::string &tensorrt_transformer_posid) {
#ifdef PADDLE_WITH_CUDA
  c->tensorrt_transformer_posid_.assign(tensorrt_transformer_posid);
#endif
}
// Records the transformer mask-id tensor name on the config.
// Only takes effect when built with PADDLE_WITH_CUDA; otherwise the call does
// nothing.
void InternalUtils::SetTransformerMaskid(
    paddle_infer::Config *c, const std::string &tensorrt_transformer_maskid) {
#ifdef PADDLE_WITH_CUDA
  c->tensorrt_transformer_maskid_.assign(tensorrt_transformer_maskid);
#endif
}
void InternalUtils::SyncStream(paddle_infer::Predictor *p) {
#ifdef PADDLE_WITH_CUDA
auto *pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get());
......
......@@ -618,14 +618,14 @@ struct PD_INFER_DECL AnalysisConfig {
/// may be more high-performance. Libnvinfer_plugin.so greater than
/// V7.2.1 is needed.
///
void EnableTensorRtOSS();
void EnableVarseqlen();
///
/// \brief A boolean state telling whether TensorRT varseqlen is enabled.
///
/// \return bool Whether TensorRT varseqlen is enabled.
///
bool tensorrt_oss_enabled() { return trt_use_oss_; }
bool tensorrt_varseqlen_enabled() { return trt_use_varseqlen_; }
///
/// \brief Enable TensorRT DLA
......@@ -954,8 +954,10 @@ struct PD_INFER_DECL AnalysisConfig {
Precision tensorrt_precision_mode_{Precision::kFloat32};
bool trt_use_static_engine_{false};
bool trt_use_calib_mode_{true};
bool trt_use_oss_{false};
bool trt_use_varseqlen_{false};
bool trt_with_interleaved_{false};
std::string tensorrt_transformer_posid_{""};
std::string tensorrt_transformer_maskid_{""};
bool trt_use_dla_{false};
int trt_dla_core_{0};
std::map<std::string, std::vector<int>> min_input_shape_{};
......
......@@ -435,6 +435,12 @@ class PD_INFER_DECL InternalUtils {
static void UpdateConfigInterleaved(paddle_infer::Config* c,
bool with_interleaved);
static void SetTransformerPosid(
paddle_infer::Config* c, const std::string& tensorrt_transformer_posid);
static void SetTransformerMaskid(
paddle_infer::Config* c, const std::string& tensorrt_transformer_maskid);
static void SyncStream(paddle_infer::Predictor* pred);
static void SyncStream(cudaStream_t stream);
template <typename T>
......
......@@ -94,11 +94,11 @@ const std::vector<std::string> kTRTSubgraphPasses({
"add_support_int8_pass", //
// "fc_fuse_pass", //
"simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"trt_embedding_eltwise_layernorm_fuse_pass", //
"preln_embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"multihead_matmul_fuse_pass_v3", //
"skip_layernorm_fuse_pass", //
"trt_multihead_matmul_fuse_pass_v2", //
"trt_multihead_matmul_fuse_pass_v3", //
"trt_skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", //
......@@ -111,8 +111,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
// "remove_padding_recover_padding_pass", //
// "delete_remove_padding_recover_padding_pass", //
"remove_padding_recover_padding_pass", //
"delete_remove_padding_recover_padding_pass", //
// "yolo_box_fuse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
......
......@@ -303,13 +303,13 @@ void PD_ConfigDisableTensorRtOPs(__pd_keep PD_Config* pd_config, size_t ops_num,
config->Exp_DisableTensorRtOPs(ops_list);
}
void PD_ConfigEnableTensorRtOSS(__pd_keep PD_Config* pd_config) {
void PD_ConfigEnableVarseqlen(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
config->EnableTensorRtOSS();
config->EnableVarseqlen();
}
PD_Bool PD_ConfigTensorRtOssEnabled(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
return config->tensorrt_oss_enabled();
return config->tensorrt_varseqlen_enabled();
}
void PD_ConfigEnableTensorRtDla(__pd_keep PD_Config* pd_config,
......
......@@ -432,7 +432,7 @@ PADDLE_CAPI_EXPORT extern void PD_ConfigDisableTensorRtOPs(
///
/// \param[in] pd_onfig config
///
PADDLE_CAPI_EXPORT extern void PD_ConfigEnableTensorRtOSS(
PADDLE_CAPI_EXPORT extern void PD_ConfigEnableVarseqlen(
__pd_keep PD_Config* pd_config);
///
/// \brief A boolean state telling whether to use the TensorRT OSS.
......
......@@ -500,8 +500,8 @@ func (config *Config) DisableTensorRtOPs(ops []string) {
/// may be more high-performance. Libnvinfer_plugin.so greater than
/// V7.2.1 is needed.
///
func (config *Config) EnableTensorRtOSS() {
C.PD_ConfigEnableTensorRtOSS(config.c)
func (config *Config) EnableVarseqlen() {
C.PD_ConfigEnableVarseqlen(config.c)
}
///
......
......@@ -54,7 +54,7 @@ func TestNewConfig(t *testing.T) {
}
config.SetTRTDynamicShapeInfo(minInputShape, maxInputShape, optInputShape, false)
config.EnableTensorRtOSS()
config.EnableVarseqlen()
t.Logf("TensorrtOssEnabled:%+v", config.TensorrtOssEnabled())
config.EnableTensorRtDLA(0)
......
......@@ -56,7 +56,11 @@ nv_library(tensorrt_converter
strided_slice_op.cc
preln_skip_layernorm.cc
roll_op.cc
transformer_input_convert_op.cc
remove_padding_op.cc
recover_padding_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter)
......@@ -30,23 +30,28 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
auto word_id_name = op_desc.Input("WordId").front();
auto pos_id_name = op_desc.Input("PosId").front();
auto pos_id_name = engine_->tensorrt_transformer_posid();
engine_->Set("ernie_pos_name", new std::string(pos_id_name));
auto sent_id_name = op_desc.Input("SentId").front();
auto mask_id_name = engine_->tensorrt_transformer_maskid();
auto word_emb_name = op_desc.Input("WordEmbedding").front();
auto pos_emb_name = op_desc.Input("PosEmbedding").front();
auto sent_emb_name = op_desc.Input("SentEmbedding").front();
std::vector<std::string> id_names;
std::vector<std::string> emb_names;
bool flag_varseqlen =
engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != "";
if (engine_->use_oss()) {
if (flag_varseqlen) {
engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
id_names =
std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
emb_names =
......@@ -106,7 +111,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8");
if (engine_->use_oss()) {
if (flag_varseqlen) {
int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
if (enable_int8) {
output_fp16 = 1;
......@@ -121,7 +126,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
output_fp16, 1,
platform::errors::InvalidArgument(
"Only Precision::KHalf(fp16) is supported when infering "
"ernie(bert) model with config.EnableTensorRtOSS(). "
"ernie(bert) model with config.EnableVarseqlen(). "
"But Precision::KFloat32 is setted."));
const std::vector<nvinfer1::PluginField> fields{
{"bert_embeddings_layernorm_beta", bias,
......@@ -159,8 +164,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_inputs.emplace_back(
engine_->GetITensor(pos_id_name)); // cu_seqlens,
// eval_placeholder_2
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
auto* shuffle_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
nvinfer1::Dims shape_dim;
......@@ -193,8 +197,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_scale);
}
if (engine_->with_interleaved()) {
VLOG(4)
<< "fused emb_eltwise_layernorm op: use_oss and with_interleaved";
VLOG(4) << "fused emb_eltwise_layernorm op: use_varseqlen and "
"with_interleaved";
if (!enable_int8) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
......@@ -229,12 +233,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name},
test_mode);
}
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
}
};
......
......@@ -250,8 +250,7 @@ class FcOpConverter : public OpConverter {
}
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead.
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[3] == 1 && x_num_col_dims == 2) {
if (x_dim.nbDims == 4 && x_num_col_dims == 1) {
if (enable_int8 || support_int8) {
// add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1);
......
......@@ -76,12 +76,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
auto output_name = op_desc.Output("Out")[0];
bool flag_varseqlen = engine_->use_varseqlen() &&
engine_->tensorrt_transformer_posid() != "" &&
engine_->tensorrt_transformer_maskid() != "";
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) {
if (flag_varseqlen) {
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
PADDLE_THROW(platform::errors::Fatal(
"use use_oss must be int8 or half, not float32."));
"use use_varseqlen must be int8 or half, not float32."));
}
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
......@@ -90,7 +92,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())};
if (engine_->with_interleaved()) {
VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved";
VLOG(4) << "fused multihead_matmul op: use_varseqlen and "
"with_interleaved";
if (!op_desc.HasAttr("Input_scale")) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
......@@ -233,9 +236,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0;
}
}
auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "2");
assert(creator != nullptr);
......@@ -272,18 +272,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(mask_tensor);
if (engine_->Has("ernie_pos_name")) {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->Get<std::string>("ernie_pos_name")));
} else {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()
->getInput(2)
->getName())); // cu_seqlens, eval_placeholder_2
}
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask"));
plugin_inputs.emplace_back(engine_->GetITensor("pos_id"));
auto max_seqlen_tensor = engine_->GetITensor("mask_id");
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
......
......@@ -32,7 +32,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
#if IS_TRT_VERSION_GE(7000)
VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
if (!(engine_->use_oss() && engine_->with_interleaved())) {
if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
PADDLE_THROW(platform::errors::Fatal(
"PrelnErnie: If you want to use oss, must be with interleaved"));
}
......
......@@ -24,7 +24,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(7000)
VLOG(4) << "convert fused preln_skip_layernorm op to tensorrt layer";
if (!(engine_->use_oss() && engine_->with_interleaved())) {
if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
PADDLE_THROW(platform::errors::Fatal(
"PrelnErnie: If you want to use oss, must be with interleaved"));
}
......@@ -60,7 +60,8 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
VLOG(4) << "fused preln_skip_layernorm op: use_oss and with_interleaved";
VLOG(4)
<< "fused preln_skip_layernorm op: use_varseqlen and with_interleaved";
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "4");
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * Recover the padding of the transformer's output (VarSeqlen -> Padding).
 */
class RecoverPadding : public OpConverter {
 public:
  // Converts a `recover_padding` op into a RecoverPaddingPlugin layer.
  //
  // The plugin receives three tensors, in this order:
  //   0. the op's "Input" tensor,
  //   1. the engine tensor registered as "pos_id",
  //   2. the engine tensor registered as "mask_id".
  // The layer's single output is bound to the op's "Out" name.
  //
  // Throws (PADDLE_THROW, Fatal) when the engine is not in dynamic-shape mode.
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Recover padding of transformer'output: VarSeqlen -> Padding.";
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "recover_padding_op: If you want to use transformer, must "
          "be with dynamic shape"));
    }

    framework::OpDesc op_desc(op, nullptr);
    auto input_name = op_desc.Input("Input").front();
    // Routed through the verbose log instead of printing to stdout
    // unconditionally on every conversion.
    VLOG(3) << "input_name: " << input_name;

    std::vector<nvinfer1::ITensor*> plugin_inputs;
    plugin_inputs.push_back(engine_->GetITensor(input_name));
    plugin_inputs.push_back(engine_->GetITensor("pos_id"));
    plugin_inputs.push_back(engine_->GetITensor("mask_id"));
    // Derive the count from the vector so it cannot drift from the inputs
    // actually pushed above.
    const int input_num = static_cast<int>(plugin_inputs.size());

    auto output_name = op_desc.Output("Out").front();

    // NOTE(review): the raw pointer is handed to AddDynamicPlugin, which
    // presumably takes ownership -- confirm against TensorRTEngine.
    plugin::RecoverPaddingPlugin* plugin = new plugin::RecoverPaddingPlugin();
    nvinfer1::ILayer* layer =
        engine_->AddDynamicPlugin(plugin_inputs.data(), input_num, plugin);

    RreplenishLayerAndOutput(layer, "recover_padding", {output_name},
                             test_mode);
  }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(recover_padding, RecoverPadding);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * Remove the padding of the transformer's input (Padding -> VarSeqlen).
 */
class RemovePadding : public OpConverter {
 public:
  // Converts a `remove_padding` op into a RemovePaddingPlugin layer.
  //
  // Plugin inputs, in order: the op's "Input" tensor, then the engine tensors
  // registered as "pos_id" and "word_id". The single plugin output is bound
  // to the op's "Out" name.
  //
  // Throws (PADDLE_THROW, Fatal) when the engine is not in dynamic-shape mode.
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Remove padding of transformer'input: Padding -> VarSeqlen";

    const bool dynamic_shape = engine_->with_dynamic_shape();
    if (!dynamic_shape) {
      PADDLE_THROW(platform::errors::Fatal(
          "remove_padding_op: If you want to use transformer, must "
          "be with dynamic shape"));
    }

    framework::OpDesc desc(op, nullptr);
    auto in_name = desc.Input("Input").front();

    // Braced initialisation evaluates left to right, so the tensors are
    // looked up in the same order as the plugin expects them.
    std::vector<nvinfer1::ITensor*> tensors{engine_->GetITensor(in_name),
                                            engine_->GetITensor("pos_id"),
                                            engine_->GetITensor("word_id")};
    size_t tensor_count = tensors.size();

    auto out_name = desc.Output("Out").front();

    auto* remove_padding_plugin = new plugin::RemovePaddingPlugin();
    nvinfer1::ILayer* plugin_layer = engine_->AddDynamicPlugin(
        tensors.data(), tensor_count, remove_padding_plugin);

    RreplenishLayerAndOutput(plugin_layer, "remove_padding_op", {out_name},
                             test_mode);
  }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(remove_padding, RemovePadding);
......@@ -52,10 +52,13 @@ class SkipLayerNormOpConverter : public OpConverter {
bool enable_int8 = op_desc.HasAttr("enable_int8");
nvinfer1::ILayer* layer = nullptr;
if (engine_->use_oss()) {
bool flag_varseqlen = engine_->use_varseqlen() &&
engine_->tensorrt_transformer_posid() != "" &&
engine_->tensorrt_transformer_maskid() != "";
if (flag_varseqlen) {
if (engine_->with_interleaved()) {
VLOG(4) << "fused skip_layernorm op: use_oss and with_interleaved";
VLOG(4)
<< "fused skip_layernorm op: use_varseqlen and with_interleaved";
if (!enable_int8) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace paddle {
namespace inference {
......@@ -74,47 +73,12 @@ class SliceOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss() && engine_->with_ernie() &&
input_dims.nbDims == 4) {
std::vector<nvinfer1::ITensor*> plugin_inputs;
if (engine_->with_interleaved()) {
auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Permutation transpose_embed{2, 1, 0, 3};
shuffler_slice->setSecondTranspose(transpose_embed);
engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0),
out_scale);
shuffler_slice->setName(
("SpecialSlice_interleaved: transpose: (Output: " + output_name +
")")
.c_str());
plugin_inputs.emplace_back(shuffler_slice->getOutput(0));
} else {
plugin_inputs.emplace_back(input);
}
std::string pos_name;
if (engine_->Has("ernie_pos_name")) {
pos_name = engine_->Get<std::string>("ernie_pos_name");
} else {
// hard code for compatibility
pos_name = engine_->network()->getInput(2)->getName();
}
plugin_inputs.emplace_back(
engine_->GetITensor(pos_name)); // cu_seqlens, eval_placeholder_2
// bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SpecialSlicePluginDynamic* plugin =
new plugin::SpecialSlicePluginDynamic();
layer = engine_->AddDynamicPlugin(plugin_inputs.data(),
plugin_inputs.size(), plugin);
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
int decrease_axis =
decrease_axises.size() == 0 ? -1 : decrease_axises[0];
int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0];
plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
starts, ends, axes, decrease_axis, with_fp16);
layer = engine_->AddDynamicPlugin(&input, 1, plugin);
}
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Convert Transformer Input(pos_id, max_seqlen).
*/
class TransformerInputConvert : public OpConverter {
 public:
  // Converts a `transformer_input_convert` op into a
  // TransformerInputConvertPlugin layer.
  //
  // Every tensor listed under the op's "Input" slot is forwarded to the
  // plugin. The two plugin outputs are registered under the fixed names
  // "pos_id_tensor" and "max_seqlen_tensor".
  //
  // Throws (PADDLE_THROW, Fatal) when the engine is not in dynamic-shape mode.
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Convert Transformer Input(pos_id, max_seqlen), use "
               "transformer_input_convert_plugin";
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "transformer_input_convert_op: If you want to use transformer, must "
          "be with dynamic shape"));
    }

    framework::OpDesc op_desc(op, nullptr);

    // Collect one ITensor* per declared input. Previously the input count was
    // taken from the op while the data pointer addressed a single local
    // ITensor*, so an op with more than one input would have read past it.
    const auto input_names = op_desc.Input("Input");
    std::vector<nvinfer1::ITensor*> plugin_inputs;
    plugin_inputs.reserve(input_names.size());
    for (const auto& input_name : input_names) {
      plugin_inputs.push_back(engine_->GetITensor(input_name));
    }
    const int input_num = static_cast<int>(plugin_inputs.size());

    // tensorrt_subgraph_pass will rename tensor, so the op's own output names
    // (op_desc.Output("PosId") / op_desc.Output("MaxSeqlen")) are not used;
    // fixed names are registered instead.
    auto pos_id_name = "pos_id_tensor";
    auto max_seqlen_name = "max_seqlen_tensor";

    // NOTE(review): the raw pointer is handed to AddDynamicPlugin, which
    // presumably takes ownership -- confirm against TensorRTEngine.
    plugin::TransformerInputConvertPlugin* plugin =
        new plugin::TransformerInputConvertPlugin();
    nvinfer1::ILayer* layer =
        engine_->AddDynamicPlugin(plugin_inputs.data(), input_num, plugin);

    RreplenishLayerAndOutput(layer, "transformer_input_convert",
                             {pos_id_name, max_seqlen_name}, test_mode);
  }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(transformer_input_convert, TransformerInputConvert);
......@@ -410,14 +410,19 @@ class TensorRTEngine {
suffix_counter += 1;
}
void SetUseOSS(bool use_oss) { use_oss_ = use_oss; }
void SetUseOSS(bool use_varseqlen) { use_varseqlen_ = use_varseqlen; }
void SetUseDLA(bool use_dla) { use_dla_ = use_dla; }
void SetDLACore(int dla_core) { dla_core_ = dla_core; }
void SetWithErnie(bool with_ernie) { with_ernie_ = with_ernie; }
void SetWithInterleaved(bool with_interleaved) {
with_interleaved_ = with_interleaved;
}
void SetTransformerPosid(std::string tensorrt_transformer_posid) {
tensorrt_transformer_posid_ = tensorrt_transformer_posid;
}
void SetTransformerMaskid(std::string tensorrt_transformer_maskid) {
tensorrt_transformer_maskid_ = tensorrt_transformer_maskid;
}
void ClearWeights() {
for (auto& weight_pair : weight_map) {
weight_pair.second.reset(nullptr);
......@@ -488,9 +493,15 @@ class TensorRTEngine {
return ret;
}
bool use_oss() { return use_oss_; }
bool use_varseqlen() { return use_varseqlen_; }
bool with_ernie() { return with_ernie_; }
bool with_interleaved() { return with_interleaved_; }
std::string tensorrt_transformer_posid() {
return tensorrt_transformer_posid_;
}
std::string tensorrt_transformer_maskid() {
return tensorrt_transformer_maskid_;
}
bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; }
bool with_dynamic_shape() { return with_dynamic_shape_; }
AnalysisConfig::Precision precision() { return precision_; }
......@@ -612,11 +623,13 @@ class TensorRTEngine {
ShapeMapType max_input_shape_;
ShapeMapType optim_input_shape_;
bool disable_trt_plugin_fp16_{false};
bool use_oss_{false};
bool use_varseqlen_{false};
bool use_dla_{false};
int dla_core_{0};
bool with_ernie_{false};
bool with_interleaved_{false};
std::string tensorrt_transformer_posid_;
std::string tensorrt_transformer_maskid_;
nvinfer1::ILogger& logger_;
// max data size for the buffers.
......
......@@ -125,7 +125,10 @@ struct SimpleOpTypeSetTeller : public Teller {
"strided_slice",
"fused_preln_embedding_eltwise_layernorm",
"roll",
"preln_skip_layernorm"};
"preln_skip_layernorm",
"transformer_input_convert",
"recover_padding",
"remove_padding"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
......@@ -194,7 +197,10 @@ struct SimpleOpTypeSetTeller : public Teller {
"fused_preln_embedding_eltwise_layernorm",
"preln_skip_layernorm",
"roll",
"multiclass_nms3"};
"multiclass_nms3",
"transformer_input_convert",
"recover_padding",
"remove_padding"};
};
bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
......
......@@ -4,7 +4,7 @@ nv_library(tensorrt_plugin
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu
anchor_generator_op_plugin.cu
yolo_box_op_plugin.cu
yolo_box_head_op_plugin.cu
......@@ -14,6 +14,9 @@ nv_library(tensorrt_plugin
pool3d_op_plugin.cu
deformable_conv_op_plugin.cu
matmul_op_int8_plugin.cu
transformer_input_convert_plugin.cu
remove_padding_plugin.cu
recover_padding_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Scatters packed (variable-length) rows back into a zero-padded layout.
//
// Launch geometry, as read by the index arithmetic below:
//   blockIdx.x            - sequence index; input1[x] .. input1[x + 1] is the
//                           row range of that sequence in the packed input
//   blockIdx.y            - position inside the padded sequence
//   blockIdx.z/threadIdx.x - element inside one row; a row is
//                           gridDim.z * blockDim.x elements wide
// Positions at or beyond the sequence length are filled with zero.
__global__ void RecoverPaddingKernel(const float* input0, const int32_t* input1,
                                     float* output) {
  // Unsigned arithmetic is kept on purpose so indices are computed exactly as
  // the built-in (unsigned) grid/block variables would yield.
  const unsigned int row_width = gridDim.z * blockDim.x;
  const unsigned int column = blockIdx.z * blockDim.x + threadIdx.x;
  const unsigned int padded_row = blockIdx.x * gridDim.y + blockIdx.y;

  const int32_t sequence_begin = input1[blockIdx.x];
  const int32_t sequence_length = input1[blockIdx.x + 1] - sequence_begin;

  float value = 0;
  if (blockIdx.y < sequence_length) {
    const unsigned int packed_row = sequence_begin + blockIdx.y;
    value = input0[packed_row * row_width + column];
  }
  output[padded_row * row_width + column] = value;
}
// Reports the element type of the plugin output: always the type of input 0,
// regardless of which output index is queried.
nvinfer1::DataType RecoverPaddingPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  const nvinfer1::DataType first_input_type = input_types[0];
  return first_input_type;
}
// Builds the symbolic shape of the padded output, a rank-3 tensor:
//   d[0] = inputs[1].d[0] - 1
//   d[1] = inputs[2].d[1]
//   d[2] = inputs[0].d[1]
// The converter in this change feeds inputs[1] from "pos_id" and inputs[2]
// from "mask_id".
nvinfer1::DimsExprs RecoverPaddingPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs& packed_input = inputs[0];
  const nvinfer1::DimsExprs& pos_id = inputs[1];
  const nvinfer1::DimsExprs& mask_id = inputs[2];

  const auto* one = exprBuilder.constant(1);
  const auto* sequence_count = exprBuilder.operation(
      nvinfer1::DimensionOperation::kSUB, *pos_id.d[0], *one);

  nvinfer1::DimsExprs padded_dims{};
  padded_dims.nbDims = 3;
  padded_dims.d[0] = sequence_count;
  padded_dims.d[1] = mask_id.d[1];
  padded_dims.d[2] = packed_input.d[1];
  return padded_dims;
}
// Tells TensorRT which type/format is acceptable for the tensor at `pos`.
//
// Every tensor must use the linear format. The tensor at index 1 must be
// INT32; every other index (inputs 0 and 2, and the output) must be FP32.
// Half and int8 are not accepted: the kernel in this file only handles
// `float` data.
//
// Enforces (PADDLE_ENFORCE_EQ) exactly 3 inputs and getNbOutputs() outputs.
bool RecoverPaddingPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs, 3,
                    platform::errors::InvalidArgument("Must have 3 inputs, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 1 output, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));

  const nvinfer1::PluginTensorDesc& tensor_desc = inOut[pos];
  const nvinfer1::DataType required_type = (pos == 1)
                                               ? nvinfer1::DataType::kINT32
                                               : nvinfer1::DataType::kFLOAT;
  if (tensor_desc.type != required_type) {
    return false;
  }
  return tensor_desc.format == nvinfer1::TensorFormat::kLINEAR;
}
// Intentionally empty: the plugin stores nothing derived from the tensor
// descriptions passed in here.
void RecoverPaddingPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT {}
// Intentionally empty: none of the provided cuDNN/cuBLAS/allocator handles
// are retained.
void RecoverPaddingPlugin::attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
// Intentionally empty: attachToContext() keeps nothing, so there is nothing
// to release.
void RecoverPaddingPlugin::detachFromContext() TRT_NOEXCEPT {}
// Intentionally empty: the plugin acquires nothing that needs tearing down.
void RecoverPaddingPlugin::terminate() TRT_NOEXCEPT {}
// Launches RecoverPaddingKernel on `stream`.
//
// Inputs (as wired by the converter in this change):
//   inputs[0] - packed float data; inputDesc[0].dims.d[1] is the row width
//   inputs[1] - int32 pos_id tensor; its length minus one is the number of
//               sequences
//   inputs[2] - mask_id; only its shape (dims.d[1], the padded sequence
//               length) is used here
// Output: outputs[0], float.
//
// Returns 0 on success and non-zero when the kernel launch reported an error.
int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                  const nvinfer1::PluginTensorDesc* outputDesc,
                                  const void* const* inputs,
                                  void* const* outputs, void* workspace,
                                  cudaStream_t stream) TRT_NOEXCEPT {
  const auto input0_desc = inputDesc[0];
  const auto input1_desc = inputDesc[1];
  const auto input2_desc = inputDesc[2];
  const float* input0 = static_cast<const float*>(inputs[0]);
  const int32_t* input1 =
      static_cast<const int32_t*>(inputs[1]);  // pos_id_tensor
  float* output = static_cast<float*>(outputs[0]);

  // The kernel treats gridDim.z * blockDim.x as the row width, so the block
  // size must divide the row width exactly. A fixed block size of 256 with
  // truncating division left the tail of each row unwritten whenever the row
  // width was not a multiple of 256. Pick the largest power of two <= 256
  // that divides it instead; for multiples of 256 this is still 256.
  const int32_t hidden_size = input0_desc.dims.d[1];
  int32_t num_threads = 256;
  while (num_threads > 1 && hidden_size % num_threads != 0) {
    num_threads /= 2;
  }

  // Grid: (number of sequences, padded sequence length, blocks per row).
  const dim3 num_blocks(input1_desc.dims.d[0] - 1, input2_desc.dims.d[1],
                        hidden_size / num_threads);
  RecoverPaddingKernel<<<num_blocks, num_threads, 0, stream>>>(input0, input1,
                                                               output);
  return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cassert>
#include <string>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// TensorRT dynamic-shape plugin that turns a packed (variable sequence
// length) tensor back into a padded one. The plugin is stateless: it has no
// members, serialises to zero bytes, and clone() simply creates a fresh
// instance.
class RecoverPaddingPlugin : public DynamicPluginTensorRT {
 public:
  RecoverPaddingPlugin() {}

  // Deserialising constructor. There is no state to restore, so both
  // arguments are ignored.
  RecoverPaddingPlugin(void const* serial_data, size_t serial_length) {}

  // Returns a new, independent plugin instance.
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    RecoverPaddingPlugin* ptr = new RecoverPaddingPlugin();
    return ptr;
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "recover_padding_plugin";
  }

  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }

  // `override` added so the compiler checks these against the base class,
  // matching every other virtual in this class.
  int initialize() TRT_NOEXCEPT override { return 0; }
  void terminate() TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* outputs,
                       int nbOutputs) TRT_NOEXCEPT override;

  // No scratch memory is needed.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
                       nvinfer1::IGpuAllocator* gpuAllocator)
      TRT_NOEXCEPT override;

  void detachFromContext() TRT_NOEXCEPT override;

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType* inputTypes,
      int nbInputs) const TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 protected:
  // The plugin has no state, so nothing is written when serialising.
  size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }
  void serialize(void* buffer) const TRT_NOEXCEPT override {}
};
// Plugin creator for RecoverPaddingPlugin ("recover_padding_plugin",
// version "1"). Only deserialisation yields a plugin; createPlugin() always
// returns nullptr.
class RecoverPaddingPluginCreator : public nvinfer1::IPluginCreator {
 public:
  RecoverPaddingPluginCreator() {}

  const char* getPluginName() const TRT_NOEXCEPT override {
    return "recover_padding_plugin";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  // The plugin takes no creation fields; the collection is empty.
  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
    return &field_collection_;
  }

  // Creation from a field collection is not provided.
  nvinfer1::IPluginV2* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* plugin_field)
      TRT_NOEXCEPT override {
    return nullptr;
  }

  // Rebuilds a plugin from serialised bytes and sets its plugin namespace to
  // `name`.
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, void const* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
    auto* restored = new RecoverPaddingPlugin(serial_data, serial_length);
    restored->setPluginNamespace(name);
    return restored;
  }

  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};
REGISTER_TRT_PLUGIN_V2(RecoverPaddingPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Gathers the valid rows of a padded tensor into a packed
// (variable-length) layout.
//
// Launch geometry, as read by the index arithmetic below:
//   blockIdx.x            - sequence index; input1[x] .. input1[x + 1] is the
//                           row range of that sequence in the packed output
//   blockIdx.y            - position inside the padded sequence
//   blockIdx.z/threadIdx.x - element inside one row; a row is
//                           gridDim.z * blockDim.x elements wide
// Positions at or beyond the sequence length are skipped (nothing written).
__global__ void RemovePaddingKernel(const float* input0, const int32_t* input1,
                                    float* output) {
  const int32_t sequence_begin = input1[blockIdx.x];
  const int32_t sequence_length = input1[blockIdx.x + 1] - sequence_begin;
  if (!(blockIdx.y < sequence_length)) {
    return;  // padding position: no packed row corresponds to it
  }

  // Unsigned arithmetic is kept on purpose so indices are computed exactly as
  // the built-in (unsigned) grid/block variables would yield.
  const unsigned int row_width = gridDim.z * blockDim.x;
  const unsigned int column = blockIdx.z * blockDim.x + threadIdx.x;
  const unsigned int padded_row = blockIdx.x * gridDim.y + blockIdx.y;
  const unsigned int packed_row = sequence_begin + blockIdx.y;

  output[packed_row * row_width + column] =
      input0[padded_row * row_width + column];
}
// Reports the element type of the plugin output: always the type of input 0,
// regardless of which output index is queried.
nvinfer1::DataType RemovePaddingPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  const nvinfer1::DataType first_input_type = input_types[0];
  return first_input_type;
}
// Builds the symbolic shape of the packed output, a rank-4 tensor:
//   d[0] = inputs[2].d[0]
//   d[1] = inputs[0].d[2]
//   d[2] = 1
//   d[3] = 1
// The converter in this change feeds inputs[2] from "word_id".
nvinfer1::DimsExprs RemovePaddingPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs& padded_input = inputs[0];
  const nvinfer1::DimsExprs& word_id = inputs[2];

  nvinfer1::DimsExprs packed_dims{};
  packed_dims.nbDims = 4;
  packed_dims.d[0] = word_id.d[0];
  packed_dims.d[1] = padded_input.d[2];
  // The two trailing unit dimensions each get their own constant expression.
  for (int axis = 2; axis < 4; ++axis) {
    packed_dims.d[axis] = exprBuilder.constant(1);
  }
  return packed_dims;
}
// Tells TensorRT which type/format is acceptable for the tensor at `pos`.
//
// Every tensor must use the linear format. Indices 1 and 2 (pos_id and
// word_id) must be INT32; every other index (input 0 and the output) must be
// FP32. Half and int8 are not accepted: the kernel in this file only handles
// `float` data.
//
// Enforces (PADDLE_ENFORCE_EQ) exactly 3 inputs and getNbOutputs() outputs.
bool RemovePaddingPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs, 3,
                    platform::errors::InvalidArgument("Must have 3 inputs, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 1 output, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));

  const nvinfer1::PluginTensorDesc& tensor_desc = inOut[pos];
  const bool is_index_tensor = (pos == 1 || pos == 2);
  const nvinfer1::DataType required_type = is_index_tensor
                                               ? nvinfer1::DataType::kINT32
                                               : nvinfer1::DataType::kFLOAT;
  if (tensor_desc.type != required_type) {
    return false;
  }
  return tensor_desc.format == nvinfer1::TensorFormat::kLINEAR;
}
// Intentionally empty: the plugin stores nothing derived from the tensor
// descriptions passed in here.
void RemovePaddingPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT {}
// Intentionally empty: none of the provided cuDNN/cuBLAS/allocator handles
// are retained.
void RemovePaddingPlugin::attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
// Intentionally empty: attachToContext() keeps nothing, so there is nothing
// to release.
void RemovePaddingPlugin::detachFromContext() TRT_NOEXCEPT {}
// Intentionally empty: the plugin acquires nothing that needs tearing down.
void RemovePaddingPlugin::terminate() TRT_NOEXCEPT {}
// Launches RemovePaddingKernel on `stream`.
//
// Inputs:
//   inputs[0] - padded float data; inputDesc[0].dims supplies the grid:
//               d[0] sequences, d[1] padded positions, d[2] row width
//   inputs[1] - int32 pos_id tensor (per-sequence row offsets)
// Output: outputs[0], float.
//
// Returns 0 on success and non-zero when the kernel launch reported an error.
int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                 const nvinfer1::PluginTensorDesc* outputDesc,
                                 const void* const* inputs,
                                 void* const* outputs, void* workspace,
                                 cudaStream_t stream) TRT_NOEXCEPT {
  const float* input0 = static_cast<const float*>(inputs[0]);
  const int32_t* input1 =
      static_cast<const int32_t*>(inputs[1]);  // pos_id_tensor
  float* output = static_cast<float*>(outputs[0]);
  const auto input0_desc = inputDesc[0];

  // The kernel treats gridDim.z * blockDim.x as the row width, so the block
  // size must divide the row width exactly. A fixed block size of 256 with
  // truncating division skipped the tail of each row whenever the row width
  // was not a multiple of 256. Pick the largest power of two <= 256 that
  // divides it instead; for multiples of 256 this is still 256.
  const int32_t hidden_size = input0_desc.dims.d[2];
  int32_t num_threads = 256;
  while (num_threads > 1 && hidden_size % num_threads != 0) {
    num_threads /= 2;
  }

  // Grid: (number of sequences, padded sequence length, blocks per row).
  const dim3 num_blocks(input0_desc.dims.d[0], input0_desc.dims.d[1],
                        hidden_size / num_threads);
  RemovePaddingKernel<<<num_blocks, num_threads, 0, stream>>>(input0, input1,
                                                              output);
  return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
// NOTE(review): this span appears to be a diff rendering whose +/- markers were
// stripped: declarations of SpecialSlicePluginDynamic are interleaved with
// those of RemovePaddingPlugin (two class heads, duplicated member
// declarations, split parameter lists). It is not compilable as written; the
// comments below describe individual lines only.
class SpecialSlicePluginDynamic : public DynamicPluginTensorRT {
class RemovePaddingPlugin : public DynamicPluginTensorRT {
public:
SpecialSlicePluginDynamic();
SpecialSlicePluginDynamic(void const* serial_data, size_t serial_length);
~SpecialSlicePluginDynamic();
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
// Default constructor with an empty body.
RemovePaddingPlugin() {}
// Constructor taking a serialized buffer; both arguments are ignored.
RemovePaddingPlugin(void const* serial_data, size_t serial_length) {}
// Returns a newly allocated, default-constructed RemovePaddingPlugin.
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
RemovePaddingPlugin* ptr = new RemovePaddingPlugin();
return ptr;
}
// Plugin type string reported by RemovePaddingPlugin.
const char* getPluginType() const TRT_NOEXCEPT override {
return "remove_padding_plugin";
}
// The plugin reports exactly one output.
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
// NOTE(review): these two declarations carry no `override`, unlike the
// initialize()/terminate() declarations further down in this span.
int initialize() TRT_NOEXCEPT { return 0; }
void terminate() TRT_NOEXCEPT;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
// Two variants of the configurePlugin parameter list are interleaved here
// (in/out versus inputs/outputs).
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT override;
// Two endings of getWorkspaceSize follow: a pure declaration and an inline
// definition that returns 0.
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator)
TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
const char* getPluginType() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override;
// Inline variant of destroy(): the object deletes itself.
void destroy() TRT_NOEXCEPT override { delete this; }
private:
int axis_;
int num_stack_;
protected:
// Inline variants: serialized size is 0 and serialize() writes nothing.
size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }
void serialize(void* buffer) const TRT_NOEXCEPT override {}
};
// NOTE(review): as with the plugin class above, this span appears to interleave
// the declarations of SpecialSlicePluginDynamicCreator with the inline
// definitions of RemovePaddingPluginCreator (diff view without +/- markers).
// It is not compilable as written; comments describe individual lines only.
class SpecialSlicePluginDynamicCreator : public nvinfer1::IPluginCreator {
class RemovePaddingPluginCreator : public nvinfer1::IPluginCreator {
public:
SpecialSlicePluginDynamicCreator();
const char* getPluginName() const TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
nvinfer1::IPluginV2* createPlugin(const char* name,
const nvinfer1::PluginFieldCollection* fc)
TRT_NOEXCEPT override;
// Default constructor with an empty body.
RemovePaddingPluginCreator() {}
// Plugin name string; identical to the type string returned by
// RemovePaddingPlugin::getPluginType().
const char* getPluginName() const TRT_NOEXCEPT override {
return "remove_padding_plugin";
}
// Plugin version string.
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
// Returns the address of the (empty) field collection member.
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
return &field_collection_;
}
// This inline variant always returns nullptr and ignores both arguments.
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* plugin_field)
TRT_NOEXCEPT override {
return nullptr;
}
// Two variants of deserializePlugin are interleaved: a pure declaration and
// an inline definition that constructs a RemovePaddingPlugin from the buffer.
nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data,
size_t serial_length) TRT_NOEXCEPT override;
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override;
const char* getPluginNamespace() const TRT_NOEXCEPT override;
const char* name, void const* serial_data,
size_t serial_length) TRT_NOEXCEPT override {
RemovePaddingPlugin* obj =
new RemovePaddingPlugin(serial_data, serial_length);
// NOTE(review): `name` is forwarded as the plugin namespace; confirm whether
// plugin_namespace_ was intended here.
obj->setPluginNamespace(name);
return obj;
}
// Stores the namespace string handed to the creator.
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
plugin_namespace_ = lib_namespace;
}
// Returns the stored namespace as a C string.
const char* getPluginNamespace() const TRT_NOEXCEPT override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
// Initialised with zero fields and a null field pointer.
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(SpecialSlicePluginDynamicCreator);
#endif
REGISTER_TRT_PLUGIN_V2(RemovePaddingPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
// Default constructor: performs no initialisation of its own.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() {
}

// Deserialising constructor: the serialized buffer and its length are accepted
// but not read.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(
    void const* /*serial_data*/, size_t /*serial_length*/) {
}

// Destructor: nothing to release.
SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() {
}
// Returns a newly allocated, default-constructed plugin. No state is copied
// from `this`.
nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const
    TRT_NOEXCEPT {
  auto* duplicate = new SpecialSlicePluginDynamic();
  return duplicate;
}
// Type string of this plugin; SpecialSlicePluginDynamicCreator::getPluginName()
// returns the same literal.
const char* SpecialSlicePluginDynamic::getPluginType() const TRT_NOEXCEPT {
  const char* const type_name = "special_slice_plugin";
  return type_name;
}

// The plugin produces exactly one output tensor.
int SpecialSlicePluginDynamic::getNbOutputs() const TRT_NOEXCEPT {
  const int output_count = 1;
  return output_count;
}

// No initialisation work is done; always returns 0.
int SpecialSlicePluginDynamic::initialize() TRT_NOEXCEPT {
  return 0;
}
// Serialized size of the plugin: always zero bytes.
size_t SpecialSlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  return static_cast<size_t>(0);
}

// Writes nothing into `buffer`, matching the zero serialization size.
void SpecialSlicePluginDynamic::serialize(void* /*buffer*/) const TRT_NOEXCEPT {
}
// Derives the symbolic output shape from the two inputs.
//
// A working shape one dimension longer than inputs[0] is built:
//   d[0]  = inputs[1].d[0] - 1
//   d[1]  = 1
//   d[i]  = inputs[0].d[i - 1]   for i >= 2
// and its reported rank is then cut to (rank of inputs[0]) - 1, so only the
// leading entries of that working shape are exposed.
// `output_index` and `nb_inputs` are not consulted.
nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs& source = inputs[0];
  nvinfer1::DimsExprs result(source);
  const int expanded_rank = source.nbDims + 1;

  // Shift every dimension from index 1 onwards one slot to the right.
  for (int dst = 2; dst < expanded_rank; ++dst) {
    result.d[dst] = source.d[dst - 1];
  }

  const auto* unit = expr_builder.constant(1);
  result.d[1] = unit;
  result.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                                       *inputs[1].d[0], *unit);

  // Drop the two trailing entries of the expanded shape ("remove padding 1").
  result.nbDims = expanded_rank - 2;
  return result;
}
// Configuration hook: intentionally empty, none of the descriptors are read.
void SpecialSlicePluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* /*in*/, int /*nbInputs*/,
    const nvinfer1::DynamicPluginTensorDesc* /*out*/,
    int /*nbOutputs*/) TRT_NOEXCEPT {
}

// The plugin requests no scratch workspace, whatever the tensor descriptors.
size_t SpecialSlicePluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* /*inputs*/, int /*nbInputs*/,
    const nvinfer1::PluginTensorDesc* /*outputs*/,
    int /*nbOutputs*/) const TRT_NOEXCEPT {
  const size_t workspace_bytes = 0;
  return workspace_bytes;
}
// Releases the plugin: the object deletes itself, so it must not be touched
// after this call returns.
void SpecialSlicePluginDynamic::destroy() TRT_NOEXCEPT {
  delete this;
}

// Nothing to tear down.
void SpecialSlicePluginDynamic::terminate() TRT_NOEXCEPT {
}
// Reports whether the tensor at position `pos` has an accepted type/format.
//
// Every position must use the linear tensor format. Position 1 (labelled
// "cu_seqlen" in the original code) must be INT32; every other position,
// including position 0 (the "slice tensor"), must be HALF. FLOAT is not
// accepted. `nb_inputs` and `nb_outputs` are not consulted.
bool SpecialSlicePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  const nvinfer1::PluginTensorDesc& tensor = desc[pos];
  const nvinfer1::DataType required_type =
      (pos == 1) ? nvinfer1::DataType::kINT32 : nvinfer1::DataType::kHALF;
  const bool type_ok = tensor.type == required_type;
  const bool format_ok = tensor.format == nvinfer1::TensorFormat::kLINEAR;
  return type_ok && format_ok;
}
// The single output has the same data type as the first input.
// Enforces that `index` is 0; `nb_inputs` is not consulted.
nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(
      index, 0,
      platform::errors::InvalidArgument("The index should be equal to 0"));
  const nvinfer1::DataType first_input_type = input_types[0];
  return first_input_type;
}
// Copies one row of `slice_input` per blockIdx.y into `output`.
//
// The row width is taken to be blockDim.x * gridDim.x, with one thread per
// element of a row. For row r the source row index is read from cu_seqlens[r]
// and the destination row index is r itself.
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
                                   const int32_t* cu_seqlens, T* output) {
  const int row = blockIdx.y;
  const int column = blockIdx.x * blockDim.x + threadIdx.x;
  const int row_width = blockDim.x * gridDim.x;

  const int src_index = cu_seqlens[row] * row_width + column;
  const int dst_index = row * row_width + column;
  output[dst_index] = slice_input[src_index];
}
// Launches SpecialSliceKernel on `stream`.
//
// inputs[0] is read as half data and inputs[1] as int32 offsets; outputs[0]
// receives half data. The original code annotates the input dims as
// (sum(S), hidden, 1, 1) and the output dims as (batch, hidden, 1, 1).
// Enforces that the first input is HALF and that `hidden` (input dim 1) is a
// multiple of 128. `workspace` is not used.
// Returns 0 when cudaGetLastError() reports success after the launch, 1
// otherwise.
int SpecialSlicePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
  const nvinfer1::PluginTensorDesc& src_desc = input_desc[0];
  const nvinfer1::PluginTensorDesc& dst_desc = output_desc[0];

  PADDLE_ENFORCE_EQ(
      src_desc.type, nvinfer1::DataType::kHALF,
      platform::errors::InvalidArgument("Type of input should be half."));

  const int32_t hidden = src_desc.dims.d[1];
  PADDLE_ENFORCE_EQ(
      hidden % 128, 0,
      platform::errors::InvalidArgument("hidden should be multiple of 128."));

  // One thread per element of a row, 128 threads per block; grid.y walks the
  // output rows.
  constexpr int threads_per_block = 128;
  const int32_t blocks_per_row = hidden / threads_per_block;
  const int32_t row_count = dst_desc.dims.d[0];
  const dim3 grid(blocks_per_row, row_count);

  const auto* src = static_cast<const half*>(inputs[0]);
  const auto* offsets = static_cast<const int32_t*>(inputs[1]);
  auto* dst = static_cast<half*>(outputs[0]);

  SpecialSliceKernel<<<grid, threads_per_block, 0, stream>>>(src, offsets,
                                                             dst);
  const cudaError_t launch_status = cudaGetLastError();
  return launch_status != cudaSuccess;
}
// Default constructor: no set-up is performed.
SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() {
}

// Name of the plugin this creator produces; matches
// SpecialSlicePluginDynamic::getPluginType().
const char* SpecialSlicePluginDynamicCreator::getPluginName() const
    TRT_NOEXCEPT {
  const char* const plugin_name = "special_slice_plugin";
  return plugin_name;
}

// Version string of the plugin.
const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const
    TRT_NOEXCEPT {
  const char* const plugin_version = "1";
  return plugin_version;
}

// Returns a pointer to the creator's field collection member.
const nvinfer1::PluginFieldCollection*
SpecialSlicePluginDynamicCreator::getFieldNames() TRT_NOEXCEPT {
  const nvinfer1::PluginFieldCollection* fields = &field_collection_;
  return fields;
}
// Creates a default-constructed plugin; `name` and `fc` are not used.
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
  nvinfer1::IPluginV2* created = new SpecialSlicePluginDynamic();
  return created;
}

// Builds a plugin through its deserialising constructor, forwarding the buffer
// and its length; `name` is not used.
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic(serial_data, serial_length);
}
// Stores a copy of the namespace string handed to the creator.
void SpecialSlicePluginDynamicCreator::setPluginNamespace(
    const char* lib_namespace) TRT_NOEXCEPT {
  this->plugin_namespace_ = lib_namespace;
}

// Returns the stored namespace as a C string owned by the creator.
const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const
    TRT_NOEXCEPT {
  const char* stored = this->plugin_namespace_.c_str();
  return stored;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Writes cumulative per-row match counts into `output0`.
//
// Launch layout (as used by TransformerInputConvertPlugin::enqueue): one block
// per row of `input`, one thread per column, so `input` is read as a
// gridDim.x x blockDim.x row-major matrix and `output0` must hold
// gridDim.x + 1 entries.
//
// An element at column c is counted when its value equals c. On return
//   output0[0]     = 0
//   output0[b + 1] = number of counted elements in rows 0..b (inclusive)
//
// Compared with the previous implementation this fixes three defects:
//   * the shared counter was never initialised (shared memory is not zeroed);
//   * the counter was read before the barrier, racing with the atomicAdds;
//   * every thread of every block ran an in-place prefix sum over output0
//     bounded by blockDim.x instead of the row count, which raced across
//     blocks and wrote past the gridDim.x + 1 output entries.
//
// To avoid any cross-block synchronisation, block b recounts rows 0..b itself.
// That costs O(rows^2) reads in total but keeps the single-launch interface.
__global__ void TransformerInputConvertKernel(const int64_t* input,
                                              int32_t* output0) {
  __shared__ int32_t prefix_count;
  if (threadIdx.x == 0) {
    prefix_count = 0;
  }
  // Make the zeroed counter visible before any thread adds to it.
  __syncthreads();

  const int column = threadIdx.x;
  int32_t matches = 0;
  for (int row = 0; row <= static_cast<int>(blockIdx.x); ++row) {
    if (column == static_cast<int>(input[row * blockDim.x + column])) {
      ++matches;
    }
  }
  if (matches > 0) {
    atomicAdd(&prefix_count, matches);
  }
  // All contributions must land before the total is published.
  __syncthreads();

  if (threadIdx.x == 0) {
    if (blockIdx.x == 0) {
      output0[0] = 0;
    }
    // Each block owns exactly one output slot, so no inter-block race remains.
    output0[blockIdx.x + 1] = prefix_count;
  }
}
// Every output of this plugin is INT32, regardless of `index` and of the
// input types.
nvinfer1::DataType TransformerInputConvertPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  const nvinfer1::DataType output_type = nvinfer1::DataType::kINT32;
  return output_type;
}
// Both outputs are one-dimensional.
//   output 0 ("PosId" in the original code):  length = inputs[0].d[0] + 1
//   other outputs ("MaxSeqlen"):              length = inputs[0].d[1]
// `nbInputs` is not consulted.
nvinfer1::DimsExprs TransformerInputConvertPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs shape{};
  shape.nbDims = 1;

  if (outputIndex != 0) {  // MaxSeqlen
    shape.d[0] = inputs[0].d[1];
    return shape;
  }

  // PosId
  const auto* unit = exprBuilder.constant(1);
  shape.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM,
                                     *inputs[0].d[0], *unit);
  return shape;
}
// Reports whether the tensor at position `pos` has an accepted type/format.
//
// Enforces exactly one input and getNbOutputs() outputs. Every tensor must use
// the linear format; tensors at pos > 0 (the outputs) must additionally be
// INT32. The input's data type is not checked here.
bool TransformerInputConvertPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs, 1,
                    platform::errors::InvalidArgument("Must have 1 inputs, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 2 output, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));

  const nvinfer1::PluginTensorDesc& tensor = inOut[pos];
  const bool is_linear = tensor.format == nvinfer1::TensorFormat::kLINEAR;
  if (pos == 0) {  // input
    return is_linear;
  }
  // output0, output1
  const bool is_int32 = tensor.type == nvinfer1::DataType::kINT32;
  return is_int32 && is_linear;
}
void TransformerInputConvertPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::detachFromContext() TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::terminate() TRT_NOEXCEPT {}
// Launches TransformerInputConvertKernel on `stream`.
//
// inputs[0] is read as int64 data and outputs[0] (labelled "PosId" in the
// original code) is written as int32. outputs[1] ("MaxSeqlen") is not written
// here, and `outputDesc` and `workspace` are not used.
// The grid size is input dim 0 (batches) and the block size is input dim 1
// (max sequence length).
// Returns 0 when cudaGetLastError() reports success after the launch, 1
// otherwise.
int TransformerInputConvertPlugin::enqueue(
    const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
  const nvinfer1::Dims& in_dims = inputDesc[0].dims;
  const int32_t grid_size = in_dims.d[0];   // batches
  const int32_t block_size = in_dims.d[1];  // max sequence length

  const auto* src = static_cast<const int64_t*>(inputs[0]);
  auto* dst = static_cast<int32_t*>(outputs[0]);  // PosId

  TransformerInputConvertKernel<<<grid_size, block_size, 0, stream>>>(src,
                                                                      dst);
  const cudaError_t launch_status = cudaGetLastError();
  return launch_status != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cassert>
#include <string>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// TensorRT dynamic-shape plugin with one input and two outputs (see
// getNbOutputs()). The plugin holds no members of its own: construction,
// cloning and serialization are all trivial.
class TransformerInputConvertPlugin : public DynamicPluginTensorRT {
 public:
  // Default constructor with an empty body.
  TransformerInputConvertPlugin() {}

  // Deserialising constructor; the buffer and its length are ignored because
  // serialize() writes nothing.
  TransformerInputConvertPlugin(void const* serial_data, size_t serial_length) {
  }

  // Returns a newly allocated, default-constructed plugin. Nothing is copied
  // from `this`.
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    TransformerInputConvertPlugin* ptr = new TransformerInputConvertPlugin();
    return ptr;
  }

  // Type string; matches TransformerInputConvertPluginCreator::getPluginName().
  const char* getPluginType() const TRT_NOEXCEPT override {
    return "transformer_input_convert_plugin";
  }

  // The plugin reports two outputs.
  int getNbOutputs() const TRT_NOEXCEPT override { return 2; }

  // No initialisation work; always returns 0. Marked `override` (as is
  // terminate() below) for consistency with every other overridden virtual in
  // this class, so that a signature mismatch becomes a compile error.
  int initialize() TRT_NOEXCEPT override { return 0; }
  void terminate() TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* outputs,
                       int nbOutputs) TRT_NOEXCEPT override;

  // No scratch workspace is requested.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
                       nvinfer1::IGpuAllocator* gpuAllocator)
      TRT_NOEXCEPT override;
  void detachFromContext() TRT_NOEXCEPT override;

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType* inputTypes,
      int nbInputs) const TRT_NOEXCEPT override;

  // The object deletes itself; it must not be used after this call.
  void destroy() TRT_NOEXCEPT override { delete this; }

 protected:
  // Nothing is serialized: the size is 0 and serialize() is a no-op.
  size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }
  void serialize(void* buffer) const TRT_NOEXCEPT override {}
};
// Creator for TransformerInputConvertPlugin. It exposes an empty field
// collection and only supports building the plugin through deserialization;
// createPlugin() always returns nullptr.
class TransformerInputConvertPluginCreator : public nvinfer1::IPluginCreator {
 public:
  TransformerInputConvertPluginCreator() {}

  // Plugin name; matches TransformerInputConvertPlugin::getPluginType().
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "transformer_input_convert_plugin";
  }

  // Plugin version string.
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  // Returns the address of the (empty) field collection member.
  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
    return &field_collection_;
  }

  // Creation from plugin fields is not supported: always returns nullptr.
  nvinfer1::IPluginV2* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* plugin_field)
      TRT_NOEXCEPT override {
    return nullptr;
  }

  // Rebuilds a plugin from its serialized form.
  // The new plugin inherits this creator's namespace. Previously the `name`
  // argument (a plugin name, not a namespace) was passed to
  // setPluginNamespace(), so a deserialized plugin reported its own name as
  // its namespace.
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, void const* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
    TransformerInputConvertPlugin* obj =
        new TransformerInputConvertPlugin(serial_data, serial_length);
    obj->setPluginNamespace(plugin_namespace_.c_str());
    return obj;
  }

  // Stores a copy of the namespace string handed to the creator.
  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = lib_namespace;
  }

  // Returns the stored namespace as a C string owned by the creator.
  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  // Initialised with zero fields and a null field pointer.
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};
REGISTER_TRT_PLUGIN_V2(TransformerInputConvertPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -65,7 +65,7 @@ TEST(PD_Config, gpu_interface) {
&min_shape_ptr, &max_shape_ptr,
&opt_shape_ptr, FALSE);
PD_ConfigDisableTensorRtOPs(config, 1, &ops_name);
PD_ConfigEnableTensorRtOSS(config);
PD_ConfigEnableVarseqlen(config);
bool oss_enabled = PD_ConfigTensorRtOssEnabled(config);
EXPECT_TRUE(oss_enabled);
......
......@@ -210,7 +210,11 @@ std::shared_ptr<paddle_infer::Predictor> InitPredictor() {
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
// ernie varlen must be used with oss
config.EnableTensorRtOSS();
config.EnableVarseqlen();
paddle_infer::experimental::InternalUtils::SetTransformerPosid(&config,
input_name2);
paddle_infer::experimental::InternalUtils::SetTransformerMaskid(&config,
input_name3);
return paddle_infer::CreatePredictor(config);
}
......
......@@ -68,7 +68,7 @@ std::shared_ptr<Predictor> InitPredictor() {
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
// ernie varlen must be used with oss
config.EnableTensorRtOSS();
config.EnableVarseqlen();
return CreatePredictor(config);
}
......
......@@ -43,7 +43,7 @@ TEST(table_printer, output) {
table.InsertRow({"trt_precision", "fp32"});
table.InsertRow({"enable_dynamic_shape", "true"});
table.InsertRow({"DisableTensorRtOPs", "{}"});
table.InsertRow({"EnableTensorRtOSS", "ON"});
table.InsertRow({"EnableVarseqlen", "ON"});
table.InsertRow({"tensorrt_dla_enabled", "ON"});
table.InsetDivider();
......
......@@ -657,8 +657,9 @@ void BindAnalysisConfig(py::module *m) {
py::arg("disable_trt_plugin_fp16") = false)
.def("tensorrt_dynamic_shape_enabled",
&AnalysisConfig::tensorrt_dynamic_shape_enabled)
.def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS)
.def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled)
.def("enable_tensorrt_varseqlen", &AnalysisConfig::EnableVarseqlen)
.def("tensorrt_varseqlen_enabled",
&AnalysisConfig::tensorrt_varseqlen_enabled)
.def("collect_shape_range_info", &AnalysisConfig::CollectShapeRangeInfo)
.def("shape_range_info_path", &AnalysisConfig::shape_range_info_path)
.def("shape_range_info_collected",
......
......@@ -42,7 +42,7 @@ class InferencePassTest(unittest.TestCase):
self.enable_mkldnn = False
self.enable_mkldnn_bfloat16 = False
self.enable_trt = False
self.enable_tensorrt_oss = True
self.enable_tensorrt_varseqlen = True
self.trt_parameters = None
self.dynamic_shape_params = None
self.enable_lite = False
......@@ -134,8 +134,8 @@ class InferencePassTest(unittest.TestCase):
self.dynamic_shape_params.max_input_shape,
self.dynamic_shape_params.optim_input_shape,
self.dynamic_shape_params.disable_trt_plugin_fp16)
if self.enable_tensorrt_oss:
config.enable_tensorrt_oss()
if self.enable_tensorrt_varseqlen:
config.enable_tensorrt_varseqlen()
elif use_mkldnn:
config.enable_mkldnn()
......
......@@ -46,7 +46,7 @@ class QuantDequantTest(unittest.TestCase):
self.enable_mkldnn = False
self.enable_mkldnn_bfloat16 = False
self.enable_trt = False
self.enable_tensorrt_oss = True
self.enable_tensorrt_varseqlen = True
self.trt_parameters = None
self.dynamic_shape_params = None
self.enable_lite = False
......@@ -184,8 +184,8 @@ class QuantDequantTest(unittest.TestCase):
self.dynamic_shape_params.max_input_shape,
self.dynamic_shape_params.optim_input_shape,
self.dynamic_shape_params.disable_trt_plugin_fp16)
if self.enable_tensorrt_oss:
config.enable_tensorrt_oss()
if self.enable_tensorrt_varseqlen:
config.enable_tensorrt_varseqlen()
elif use_mkldnn:
config.enable_mkldnn()
......
......@@ -179,7 +179,7 @@ def multiclass_nms(bboxes,
class TensorRTMultiClassNMS3Test(InferencePassTest):
def setUp(self):
self.enable_trt = True
self.enable_tensorrt_oss = True
self.enable_tensorrt_varseqlen = True
self.precision = AnalysisConfig.Precision.Float32
self.serialize = False
self.bs = 1
......@@ -291,8 +291,8 @@ class TensorRTMultiClassNMS3Test(InferencePassTest):
self.background = 7
self.run_test()
def test_disable_oss(self):
self.diable_tensorrt_oss = False
def test_disable_varseqlen(self):
self.diable_tensorrt_varseqlen = False
self.run_test()
......
......@@ -25,7 +25,7 @@ from paddle.fluid.core import AnalysisConfig
class TensorRTMultiClassNMSTest(InferencePassTest):
def setUp(self):
self.enable_trt = True
self.enable_tensorrt_oss = True
self.enable_tensorrt_varseqlen = True
self.precision = AnalysisConfig.Precision.Float32
self.serialize = False
self.bs = 1
......@@ -135,8 +135,8 @@ class TensorRTMultiClassNMSTest(InferencePassTest):
self.background = 7
self.run_test()
def test_disable_oss(self):
self.diable_tensorrt_oss = False
def test_disable_varseqlen(self):
self.diable_tensorrt_varseqlen = False
self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册