Unverified commit 178b2440, authored by: W Wangzheee, committed by: GitHub

general_prelayernorm_transformer (#43748)

Parent dbf92d49
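The passes touched below target the pre-LayerNorm ("preln") variant of the transformer block, where the residual elementwise add feeds both the next residual branch and a layer_norm, so the fused ops expose two outputs (Out_0 and Out_1). As orientation only, here is a minimal standalone C++ sketch of that numeric pattern; it is not Paddle code and the names are illustrative:

// Sketch of the pre-LayerNorm skip + layer_norm pattern matched by
// preln_skip_layernorm: Out_0 = x + skip (reused by the next residual),
// Out_1 = LayerNorm(Out_0).
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

std::pair<std::vector<float>, std::vector<float>> preln_skip_layernorm(
    const std::vector<float>& x, const std::vector<float>& skip,
    const std::vector<float>& scale, const std::vector<float>& bias,
    float eps = 1e-5f) {
  const std::size_t n = x.size();
  std::vector<float> sum(n), out(n);
  float mean = 0.f;
  for (std::size_t i = 0; i < n; ++i) {
    sum[i] = x[i] + skip[i];  // Out_0
    mean += sum[i];
  }
  mean /= static_cast<float>(n);
  float var = 0.f;
  for (std::size_t i = 0; i < n; ++i) var += (sum[i] - mean) * (sum[i] - mean);
  var /= static_cast<float>(n);
  const float inv_std = 1.f / std::sqrt(var + eps);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = (sum[i] - mean) * inv_std * scale[i] + bias[i];  // Out_1
  }
  return {sum, out};
}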
...@@ -31,7 +31,8 @@ namespace framework {
namespace ir {
namespace patterns {
static PDNode* create_emb_vars(PDPattern* pattern,
const std::string& name,
const std::string& arg,
bool is_persist = false) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
...@@ -41,7 +42,8 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
if (is_persist) return node->assert_is_persistable_var();
return node;
}
static PDNode* create_emb_out_vars(PDPattern* pattern,
const std::string& name,
const std::string& arg) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
...@@ -62,6 +64,9 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed");
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table2 =
...@@ -74,8 +79,10 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add");
feed1->LinksTo({lookup_table1_x});
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
feed2->LinksTo({lookup_table2_x});
lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})
...@@ -88,6 +95,8 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table1_out =
...@@ -101,6 +110,7 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
->assert_is_op_output("elementwise_add");
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
feed1->LinksTo({lookup_table1_x});
eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
.LinksTo({eltwise_add_out});
}
...@@ -161,10 +171,10 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
lookup_table1_out, lookup_table1_out, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
lookup_table2_out, lookup_table2_out, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
if (!IsCompat(subgraph, graph)) {
...@@ -179,8 +189,12 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
start_pattern_out_node.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({lookup_table1,
lookup_table2,
lookup_table1_out,
lookup_table2_out,
eltwise_add,
eltwise_add_out});
start_pattern_remove_nodes.push_back(rm_nodes);
};
gpd(graph, handler);
...@@ -200,8 +214,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
lookup_table1_out, lookup_table1_out, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
...@@ -236,19 +250,19 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltwise_add_out, eltwise_add_out, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_variance, layer_norm_variance, skip_layernorm_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed.";
return;
...@@ -313,7 +327,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
...@@ -433,16 +447,17 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (!(enable_int8 && use_varseqlen && with_interleaved && pos_id != "" &&
mask_id != "" && with_dynamic_shape)) {
VLOG(3) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
"enable_int8, set pos_id, set mask_id, "
"use_varseqlen, with_interleaved, with_dynamic_shape. Stop this "
"pass, "
"please reconfig.";
return;
}
int fusion_count =
PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
if (fusion_count > 0) {
...
...@@ -51,7 +51,8 @@ struct PrelnEmbedding2Eltwise1Pattern : public PatternBase {
: PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(feed1);
PATTERN_DECL_NODE(feed2);
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table1_w);
...@@ -81,6 +82,7 @@ struct PrelnEmbedding1Eltwise1Pattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(feed1);
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1);
...
...@@ -112,15 +112,21 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (!(enable_int8 && use_varseqlen && with_interleaved &&
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass) && pos_id != "" &&
mask_id != "" && with_dynamic_shape)) {
VLOG(3) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
"with_interleaved"
"use_varseqlen, preln_embedding_eltwise_layernorm_fuse_pass, "
"trt_multihead_matmul_fuse_pass"
"set pos_id, set mask_id, with_dynamic_shape. Stop this pass, "
"please "
"reconfig.";
return;
}
int found_subgraph_count = 0;
GraphPatternDetector gpd;
...@@ -155,17 +161,17 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_variance, layer_norm_variance, fused_pattern);
std::unordered_set<const Node *> del_node_set;
// Create an PrelnSkipLayerNorm op node
OpDesc new_desc(elementwise->Op()->Block());
new_desc.SetType("preln_skip_layernorm");
// inputs
...@@ -209,8 +215,8 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
...
...@@ -35,6 +35,25 @@ void EmbEltwiseLayernorm::operator()() {
emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out});
}
void PrelnEmbEltwiseLayernorm::operator()() {
// Create nodes for fused_preln_embedding_eltwise_layernorm.
auto* preln_emb_elt_layernorm_op =
pattern->NewNode(preln_emb_elt_layernorm_op_repr())
->assert_is_op("fused_preln_embedding_eltwise_layernorm");
auto* preln_emb_elt_layernorm_out_0 =
pattern->NewNode(preln_emb_elt_layernorm_out_0_repr())
->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
"Out_0");
auto* preln_emb_elt_layernorm_out_1 =
pattern->NewNode(preln_emb_elt_layernorm_out_1_repr())
->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
"Out_1");
// Add links for fused_preln_embedding_eltwise_layernorm op.
preln_emb_elt_layernorm_op->LinksTo(
{preln_emb_elt_layernorm_out_0, preln_emb_elt_layernorm_out_1});
}
void SkipLayernorm::operator()() {
// Create nodes for skip_layernorm.
auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())
...@@ -51,6 +70,30 @@ void SkipLayernorm::operator()() {
.LinksTo({skip_layernorm_out});
}
void PrelnSkipLayernorm::operator()() {
// Create nodes for preln_skip_layernorm.
auto* preln_skip_layernorm_x =
pattern->NewNode(preln_skip_layernorm_x_repr())
->assert_is_op_input("preln_skip_layernorm", "X");
auto* preln_skip_layernorm_y =
pattern->NewNode(preln_skip_layernorm_y_repr())
->assert_is_op_input("preln_skip_layernorm", "Y");
auto* preln_skip_layernorm_op =
pattern->NewNode(preln_skip_layernorm_op_repr())
->assert_is_op("preln_skip_layernorm");
auto* preln_skip_layernorm_out_0 =
pattern->NewNode(preln_skip_layernorm_out_0_repr())
->assert_is_op_output("preln_skip_layernorm", "Out_0");
auto* preln_skip_layernorm_out_1 =
pattern->NewNode(preln_skip_layernorm_out_1_repr())
->assert_is_op_output("preln_skip_layernorm", "Out_1");
// Add links for preln_skip_layernorm op.
preln_skip_layernorm_op
->LinksFrom({preln_skip_layernorm_x, preln_skip_layernorm_y})
.LinksTo({preln_skip_layernorm_out_0, preln_skip_layernorm_out_1});
}
void MultiheadMatmul::operator()() {
// Create nodes for multihead_matmul.
auto* multihead_matmul_input =
...@@ -96,10 +139,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "" &&
(graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
graph->Has(framework::ir::kMultiheadMatmulPass)) {
VLOG(3) << "start varseqlen remove_padding_recover_padding_pass";
} else {
VLOG(3) << "remove_padding_recover_padding_pass check failed";
return;
}
...@@ -131,9 +176,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
remove_padding.SetOutput("Out", {remove_padding_out_name});
// set out_threshold for int8
if (op_node->Op()->HasAttr("Input_scale")) {
remove_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("Input_scale"));
} else {
VLOG(3) << "remove_padding_op has not out_threshold, because next op has "
"not Input_scale.";
}
auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);
...@@ -194,6 +242,15 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
if (op_node->Op()->HasAttr("out_threshold")) {
recover_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_threshold"));
} else if (op_node->Op()->HasAttr("out_0_threshold")) {
recover_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_0_threshold"));
} else if (op_node->Op()->HasAttr("out_1_threshold")) {
recover_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_1_threshold"));
} else {
VLOG(3) << "recover_padding_op has not out_threshold, because previous "
"op has not out_*_threshold.";
}
auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);
...@@ -241,9 +298,11 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"fused_embedding_eltwise_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op,
emb_elt_layernorm_op,
fused_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out,
emb_elt_layernorm_out,
fused_embedding_eltwise_layernorm);
insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out);
...@@ -263,12 +322,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"multihead_matmul";
GET_IR_NODE_FROM_SUBGRAPH(
multihead_matmul_input, multihead_matmul_input, multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(
multihead_matmul_op, multihead_matmul_op, multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(
multihead_matmul_out, multihead_matmul_out, multihead_matmul);
multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape();
...@@ -289,14 +348,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"skip_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(
skip_layernorm_x, skip_layernorm_x, skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
skip_layernorm_y, skip_layernorm_y, skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
skip_layernorm_op, skip_layernorm_op, skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
skip_layernorm_out, skip_layernorm_out, skip_layernorm);
std::vector<int64_t> skip_layernorm_x_shape =
skip_layernorm_x->Var()->GetShape();
...@@ -417,6 +476,86 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
};
gpd4(graph, handler4);
GraphPatternDetector gpd5;
patterns::PrelnEmbEltwiseLayernorm fused_preln_embedding_eltwise_layernorm(
gpd5.mutable_pattern(), "remove_padding_recover_padding_pass");
fused_preln_embedding_eltwise_layernorm();
auto handler5 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"fused_preln_embedding_eltwise_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_op,
preln_emb_elt_layernorm_op,
fused_preln_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_0,
preln_emb_elt_layernorm_out_0,
fused_preln_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_1,
preln_emb_elt_layernorm_out_1,
fused_preln_embedding_eltwise_layernorm);
insert_recover_padding_op(preln_emb_elt_layernorm_op,
preln_emb_elt_layernorm_out_0);
insert_recover_padding_op(preln_emb_elt_layernorm_op,
preln_emb_elt_layernorm_out_1);
found_subgraph_count++;
};
gpd5(graph, handler5);
GraphPatternDetector gpd6;
patterns::PrelnSkipLayernorm preln_skip_layernorm(
gpd6.mutable_pattern(), "remove_padding_recover_padding_pass");
preln_skip_layernorm();
auto handler6 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"preln_skip_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(
preln_skip_layernorm_x, preln_skip_layernorm_x, preln_skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
preln_skip_layernorm_y, preln_skip_layernorm_y, preln_skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
preln_skip_layernorm_op, preln_skip_layernorm_op, preln_skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_0,
preln_skip_layernorm_out_0,
preln_skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_1,
preln_skip_layernorm_out_1,
preln_skip_layernorm);
std::vector<int64_t> skip_layernorm_x_shape =
preln_skip_layernorm_x->Var()->GetShape();
if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
check_flag = false;
}
}
if (!check_flag) {
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
insert_remove_padding_op(preln_skip_layernorm_x, preln_skip_layernorm_op);
insert_remove_padding_op(preln_skip_layernorm_y, preln_skip_layernorm_op);
insert_recover_padding_op(preln_skip_layernorm_op,
preln_skip_layernorm_out_0);
insert_recover_padding_op(preln_skip_layernorm_op,
preln_skip_layernorm_out_1);
found_subgraph_count++;
};
gpd6(graph, handler6);
AddStatis(found_subgraph_count);
}
...
...@@ -41,6 +41,16 @@ struct EmbEltwiseLayernorm : public PatternBase {
PATTERN_DECL_NODE(emb_elt_layernorm_out);
};
struct PrelnEmbEltwiseLayernorm : public PatternBase {
PrelnEmbEltwiseLayernorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_emb_elt_layernorm") {}
void operator()();
PATTERN_DECL_NODE(preln_emb_elt_layernorm_op);
PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_0);
PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_1);
};
struct SkipLayernorm : public PatternBase {
SkipLayernorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "skip_layernorm") {}
...@@ -53,6 +63,19 @@ struct SkipLayernorm : public PatternBase {
PATTERN_DECL_NODE(skip_layernorm_out);
};
struct PrelnSkipLayernorm : public PatternBase {
PrelnSkipLayernorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_skip_layernorm") {}
void operator()();
PATTERN_DECL_NODE(preln_skip_layernorm_x);
PATTERN_DECL_NODE(preln_skip_layernorm_y);
PATTERN_DECL_NODE(preln_skip_layernorm_op);
PATTERN_DECL_NODE(preln_skip_layernorm_out_0);
PATTERN_DECL_NODE(preln_skip_layernorm_out_1);
};
struct MultiheadMatmul : public PatternBase {
MultiheadMatmul(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul") {}
...
...@@ -51,11 +51,20 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* mul0,
Node* mul1,
Node* mul2,
Node* mul0_out,
Node* mul1_out,
Node* mul2_out,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* eltadd_qk_b,
Node* reshape2,
Node* reshape2_qkv_out,
Node* scale,
Node* scale_out) {
auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
// auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias"));
...@@ -123,11 +132,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
...@@ -135,21 +144,21 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
...@@ -172,24 +181,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);
fuse_creater(input0,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
eltadd0_b,
eltadd1_b,
eltadd2_b,
eltadd_qk_b,
reshape2_0,
reshape2_qkv_out,
scale,
scale_out);
std::unordered_set<const Node*> marked_nodes(
{eltadd0,
...@@ -777,14 +798,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* mul0,
Node* mul1,
Node* mul2,
Node* mul0_out,
Node* mul1_out,
Node* mul2_out,
Node* mul0_w,
Node* mul1_w,
Node* mul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* eltadd_qk_b,
Node* reshape2,
Node* reshape2_qkv_out,
Node* scale,
Node* scale_out,
Node* softmax_qk,
Node* eltadd0,
Node* eltadd1,
Node* eltadd2,
Node* matmul_qk,
Node* reshape2_qkv) {
auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
...@@ -842,7 +879,8 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data,
tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());
scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
...@@ -854,15 +892,17 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(
tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size,
bv_data,
sizeof(float) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
...@@ -944,11 +984,11 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
...@@ -956,21 +996,21 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
...@@ -993,20 +1033,20 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, this pass will not take
...@@ -1018,10 +1058,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
if (is_fc_params_shared) {
return;
}
fuse_creater(input0,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul0_w,
mul1_w,
mul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
eltadd_qk_b,
reshape2_0,
reshape2_qkv_out,
scale,
scale_out,
softmax_qk,
eltadd0,
eltadd1,
eltadd2,
matmul_qk,
reshape2_qkv);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
...@@ -1083,19 +1143,28 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
int fusion_count = BuildFusionV2(graph, name_scope_, scope);
if (fusion_count > 0) {
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "") {
if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) {
if (with_interleaved) {
VLOG(3) << "start interleaved_format "
"varseqlen_trt_multihead_matmul_fuse_pass_v2";
} else {
VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v2";
}
} else {
PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass or "
"preln_embedding_eltwise_layernorm_fuse_"
"pass. please use no_varseqlen"));
}
} else if (!use_varseqlen && pos_id == "" && mask_id == "") {
VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
} else {
PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need config: "
...@@ -1251,12 +1320,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0,
Node* mul0,
Node* mul1,
Node* mul2,
Node* mul0_out,
Node* mul1_out,
Node* mul2_out,
Node* mul0_w,
Node* mul1_w,
Node* mul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* eltadd_qk_b,
Node* reshape2,
Node* reshape2_qkv_out,
Node* matmul_qk) {
auto scale_attr = BOOST_GET_CONST(float, matmul_qk->Op()->GetAttr("alpha"));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
...@@ -1314,7 +1394,8 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data,
tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());
scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
...@@ -1326,15 +1407,17 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(
tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size,
bv_data,
sizeof(float) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
...@@ -1375,31 +1458,31 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
...@@ -1422,20 +1505,20 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, this pass will not take
...@@ -1447,9 +1530,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
if (is_fc_params_shared) {
return;
}
fuse_creater(input0,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul0_w,
mul1_w,
mul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
eltadd_qk_b,
reshape2_0,
reshape2_qkv_out,
matmul_qk);
std::unordered_set<const Node*> marked_nodes({eltadd0, std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1, eltadd1,
...@@ -1510,19 +1607,28 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const { ...@@ -1510,19 +1607,28 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
int fusion_count = BuildFusionV3(graph, name_scope_, scope); int fusion_count = BuildFusionV3(graph, name_scope_, scope);
if (fusion_count > 0) { if (fusion_count > 0) {
bool use_varseqlen = Get<bool>("use_varseqlen"); bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid"); std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid"); std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "") { if (use_varseqlen && pos_id != "" && mask_id != "") {
if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) { if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v3"; graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) {
if (with_interleaved) {
VLOG(3) << "start interleaved_format "
"varseqlen_trt_multihead_matmul_fuse_pass_v3";
} else { } else {
PADDLE_THROW(platform::errors::Fatal( VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v3";
"Use transformer'varseqlen need " }
"embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen")); } else {
PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass or "
"preln_embedding_eltwise_layernorm_fuse_"
"pass. please use no_varseqlen"));
} }
} else if (!use_varseqlen && pos_id == "" && mask_id == "") { } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v3"; VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need config: " platform::errors::Fatal("Use transformer'varseqlen need config: "
......
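The tail of this hunk tightens the varseqlen configuration check in `TrtMultiHeadMatmulV3FusePass::ApplyImpl`: the varseqlen path is now accepted when either the regular or the pre-layernorm embedding-eltwise-layernorm pass has already marked the graph, and an inconsistent pos_id/mask_id setup still aborts. A standalone sketch of that three-way decision follows; the function, enum, and argument names are made up for illustration.

```cpp
// Hedged sketch of the varseqlen gating above (names are illustrative):
//  - full varseqlen config + an upstream emb_eltwise_layernorm flag (normal
//    or pre-layernorm variant) on the graph -> take the varseqlen path;
//  - everything unset                        -> take the ordinary path;
//  - anything in between                     -> configuration error.
#include <stdexcept>
#include <string>

enum class FusePath { kVarseqlen, kNoVarseqlen };

FusePath DecideFusePath(bool use_varseqlen,
                        const std::string& pos_id,
                        const std::string& mask_id,
                        bool emb_layernorm_pass_applied,
                        bool preln_emb_layernorm_pass_applied) {
  if (use_varseqlen && !pos_id.empty() && !mask_id.empty()) {
    if (emb_layernorm_pass_applied || preln_emb_layernorm_pass_applied) {
      return FusePath::kVarseqlen;
    }
    throw std::runtime_error(
        "varseqlen requires embedding_eltwise_layernorm_fuse_pass or "
        "preln_embedding_eltwise_layernorm_fuse_pass");
  }
  if (!use_varseqlen && pos_id.empty() && mask_id.empty()) {
    return FusePath::kNoVarseqlen;
  }
  throw std::runtime_error("inconsistent varseqlen configuration");
}
```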
...@@ -139,12 +139,12 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -139,12 +139,12 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, GET_IR_NODE_FROM_SUBGRAPH(
fused_pattern); layer_norm_scale, layer_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, GET_IR_NODE_FROM_SUBGRAPH(
fused_pattern); layer_norm_variance, layer_norm_variance, fused_pattern);
std::unordered_set<const Node *> del_node_set; std::unordered_set<const Node *> del_node_set;
...@@ -197,13 +197,15 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -197,13 +197,15 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid"); std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "") { if (use_varseqlen && pos_id != "" && mask_id != "") {
if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) && if ((graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
graph->Has(framework::ir::kMultiheadMatmulPass)) { graph->Has(framework::ir::kMultiheadMatmulPass)) {
VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass"; VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass";
} else { } else {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"Use transformer'varseqlen need " "Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen")); "trt_embedding_eltwise_layernorm_fuse_pass, "
"trt_multihead_matmul_fuse_pass. please use no_varseqlen"));
} }
} else if (!use_varseqlen && pos_id == "" && mask_id == "") { } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass"; VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
......
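These fuse passes coordinate through boolean attributes stored on the `Graph` (`kEmbEltwiseLayernormPass`, `kPrelnEmbEltwiseLayernormPass`, `kMultiheadMatmulPass`): the upstream pass sets the flag and the downstream passes query it with `Has()`, as in the checks above. A minimal sketch of the producer side is below, assuming the usual `Graph::Has`/`Set` attribute API; the helper name is hypothetical.

```cpp
// Minimal sketch of how an upstream pass advertises that it ran, so the
// graph->Has(...) checks in later passes succeed. Assumes the standard
// Graph::Has/Set attribute API; the function name is illustrative.
#include <string>

#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

static void MarkPassApplied(Graph* graph, const std::string& flag_name) {
  if (!graph->Has(flag_name)) {
    // The graph takes ownership of attributes registered via Set().
    graph->Set(flag_name, new bool(true));
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
```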
...@@ -28,13 +28,21 @@ namespace tensorrt { ...@@ -28,13 +28,21 @@ namespace tensorrt {
class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer"; VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
if (!(engine_->use_varseqlen() && engine_->with_interleaved())) { auto pos_id_name = engine_->tensorrt_transformer_posid();
auto mask_id_name = engine_->tensorrt_transformer_maskid();
bool flag_prelayernorm = engine_->with_interleaved() &&
engine_->use_varseqlen() && pos_id_name != "" &&
mask_id_name != "";
if (!flag_prelayernorm) {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"PrelnErnie: If you want to use oss, must be with interleaved")); "PrelnErnie: If you want to use varseqlen, must be with interleaved, "
"set pos_id_name, set mask_id_name."));
} }
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
bool enable_int8 = op_desc.HasAttr("enable_int8"); bool enable_int8 = op_desc.HasAttr("enable_int8");
...@@ -43,7 +51,6 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -43,7 +51,6 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
platform::errors::Fatal("use with_interleaved must be int8.")); platform::errors::Fatal("use with_interleaved must be int8."));
} }
auto word_id_name = op_desc.Input("WordId").front(); auto word_id_name = op_desc.Input("WordId").front();
auto pos_id_name = op_desc.Input("PosId").front();
engine_->Set("ernie_pos_name", new std::string(pos_id_name)); engine_->Set("ernie_pos_name", new std::string(pos_id_name));
auto sent_id_name = op_desc.Input("SentId").front(); auto sent_id_name = op_desc.Input("SentId").front();
...@@ -51,6 +58,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -51,6 +58,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
auto pos_emb_name = op_desc.Input("PosEmbedding").front(); auto pos_emb_name = op_desc.Input("PosEmbedding").front();
auto sent_emb_name = op_desc.Input("SentEmbedding").front(); auto sent_emb_name = op_desc.Input("SentEmbedding").front();
engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
std::vector<std::string> emb_names; std::vector<std::string> emb_names;
emb_names = emb_names =
std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name}; std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
...@@ -81,7 +92,8 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -81,7 +92,8 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
input_embs.push_back(emb_data); input_embs.push_back(emb_data);
emb_sizes.push_back(emb_size); emb_sizes.push_back(emb_size);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
emb_dims.size(), 2, emb_dims.size(),
2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims.")); "The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims."));
} }
...@@ -97,23 +109,31 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -97,23 +109,31 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
int output_int8 = 1; int output_int8 = 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_num, 3, input_num,
3,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When using oss and var-len, embedding_eltwise_layernorm op" "When using oss and var-len, embedding_eltwise_layernorm op"
"should have 3 inputs only, but got %d.", "should have 3 inputs only, but got %d.",
input_num)); input_num));
const std::vector<nvinfer1::PluginField> fields{ const std::vector<nvinfer1::PluginField> fields{
{"bert_embeddings_layernorm_beta", bias, {"bert_embeddings_layernorm_beta",
nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(bias_size)}, bias,
{"bert_embeddings_layernorm_gamma", scale, nvinfer1::PluginFieldType::kFLOAT32,
nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(scale_size)}, static_cast<int32_t>(bias_size)},
{"bert_embeddings_word_embeddings", input_embs[0], {"bert_embeddings_layernorm_gamma",
scale,
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(scale_size)},
{"bert_embeddings_word_embeddings",
input_embs[0],
nvinfer1::PluginFieldType::kFLOAT32, nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[0])}, static_cast<int32_t>(emb_sizes[0])},
{"bert_embeddings_token_type_embeddings", input_embs[2], {"bert_embeddings_token_type_embeddings",
input_embs[2],
nvinfer1::PluginFieldType::kFLOAT32, nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[2])}, static_cast<int32_t>(emb_sizes[2])},
{"bert_embeddings_position_embeddings", input_embs[1], {"bert_embeddings_position_embeddings",
input_embs[1],
nvinfer1::PluginFieldType::kFLOAT32, nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[1])}, static_cast<int32_t>(emb_sizes[1])},
{"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1}, {"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},
...@@ -136,8 +156,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -136,8 +156,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_inputs.emplace_back( plugin_inputs.emplace_back(
engine_->GetITensor(pos_id_name)); // cu_seqlens, engine_->GetITensor(pos_id_name)); // cu_seqlens,
// eval_placeholder_2 // eval_placeholder_2
auto max_seqlen_tensor = auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto* shuffle_layer = auto* shuffle_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor); TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
nvinfer1::Dims shape_dim; nvinfer1::Dims shape_dim;
......
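The converter changes above pull `pos_id`/`mask_id` from the engine configuration instead of the op inputs and reflow the plugin-field list to one item per line. For reference, a self-contained sketch of the TensorRT plugin-field plumbing it relies on is below; the field names match the ones used above, while the buffer pointers and lengths are caller-supplied placeholders rather than the real embedding weights.

```cpp
// Hedged sketch of assembling an nvinfer1::PluginFieldCollection like the
// one built above. Each PluginField records a name, a raw data pointer, a
// field type, and an element count; the collection is later handed to the
// plugin creator. The storage vector must outlive any use of the collection.
#include <vector>

#include <NvInferRuntime.h>

nvinfer1::PluginFieldCollection BuildEmbLayernormFields(
    const float* beta, int32_t beta_len,
    const float* gamma, int32_t gamma_len,
    std::vector<nvinfer1::PluginField>* storage) {
  storage->clear();
  storage->emplace_back("bert_embeddings_layernorm_beta", beta,
                        nvinfer1::PluginFieldType::kFLOAT32, beta_len);
  storage->emplace_back("bert_embeddings_layernorm_gamma", gamma,
                        nvinfer1::PluginFieldType::kFLOAT32, gamma_len);
  nvinfer1::PluginFieldCollection collection;
  collection.nbFields = static_cast<int32_t>(storage->size());
  collection.fields = storage->data();
  return collection;
}
```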