未验证 提交 178b2440 编写于 作者: W Wangzheee 提交者: GitHub

general_prelayernorm_transformer (#43748)

上级 dbf92d49
......@@ -31,7 +31,8 @@ namespace framework {
namespace ir {
namespace patterns {
static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
static PDNode* create_emb_vars(PDPattern* pattern,
const std::string& name,
const std::string& arg,
bool is_persist = false) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
......@@ -41,7 +42,8 @@ static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
if (is_persist) return node->assert_is_persistable_var();
return node;
}
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
static PDNode* create_emb_out_vars(PDPattern* pattern,
const std::string& name,
const std::string& arg) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
......@@ -62,6 +64,9 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
auto* feed2 = pattern->NewNode(feed2_repr())->assert_is_op("feed");
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table2 =
......@@ -74,8 +79,10 @@ void PrelnEmbedding2Eltwise1Pattern::operator()() {
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add");
feed1->LinksTo({lookup_table1_x});
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
feed2->LinksTo({lookup_table2_x});
lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})
......@@ -88,6 +95,8 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* feed1 = pattern->NewNode(feed1_repr())->assert_is_op("feed");
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table1_out =
......@@ -101,6 +110,7 @@ void PrelnEmbedding1Eltwise1Pattern::operator()() {
->assert_is_op_output("elementwise_add");
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
feed1->LinksTo({lookup_table1_x});
eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
.LinksTo({eltwise_add_out});
}
......@@ -161,10 +171,10 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
lookup_table1_out, lookup_table1_out, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
lookup_table2_out, lookup_table2_out, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
if (!IsCompat(subgraph, graph)) {
......@@ -179,8 +189,12 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
start_pattern_out_node.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
lookup_table2_out, eltwise_add, eltwise_add_out});
rm_nodes.insert({lookup_table1,
lookup_table2,
lookup_table1_out,
lookup_table2_out,
eltwise_add,
eltwise_add_out});
start_pattern_remove_nodes.push_back(rm_nodes);
};
gpd(graph, handler);
......@@ -200,8 +214,8 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
lookup_table1_out, lookup_table1_out, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
......@@ -236,19 +250,19 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
eltwise_add_out, eltwise_add_out, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_variance, layer_norm_variance, skip_layernorm_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed.";
return;
......@@ -313,7 +327,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
OpDesc new_op_desc;
OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
......@@ -433,16 +447,17 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_varseqlen && with_interleaved &&
with_dynamic_shape)) {
VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
"enable_int8, "
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (!(enable_int8 && use_varseqlen && with_interleaved && pos_id != "" &&
mask_id != "" && with_dynamic_shape)) {
VLOG(3) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
"enable_int8, set pos_id, set mask_id, "
"use_varseqlen, with_interleaved, with_dynamic_shape. Stop this "
"pass, "
"please reconfig.";
return;
}
int fusion_count =
PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
if (fusion_count > 0) {
......
......@@ -51,7 +51,8 @@ struct PrelnEmbedding2Eltwise1Pattern : public PatternBase {
: PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(feed1);
PATTERN_DECL_NODE(feed2);
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table1_w);
......@@ -81,6 +82,7 @@ struct PrelnEmbedding1Eltwise1Pattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(feed1);
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1);
......
......@@ -112,15 +112,21 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (!(enable_int8 && use_varseqlen && with_interleaved &&
with_dynamic_shape)) {
VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
"use_varseqlen, "
"with_interleaved, with_dynamic_shape. Stop this pass, please "
"reconfig. ";
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass) && pos_id != "" &&
mask_id != "" && with_dynamic_shape)) {
VLOG(3) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
"with_interleaved"
"use_varseqlen, preln_embedding_eltwise_layernorm_fuse_pass, "
"trt_multihead_matmul_fuse_pass"
"set pos_id, set mask_id, with_dynamic_shape. Stop this pass, "
"please "
"reconfig.";
return;
}
int found_subgraph_count = 0;
GraphPatternDetector gpd;
......@@ -155,17 +161,17 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_variance, layer_norm_variance, fused_pattern);
std::unordered_set<const Node *> del_node_set;
// Create an PrelnSkipLayerNorm op node
OpDesc new_desc;
OpDesc new_desc(elementwise->Op()->Block());
new_desc.SetType("preln_skip_layernorm");
// inputs
......@@ -209,8 +215,8 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
......
......@@ -35,6 +35,25 @@ void EmbEltwiseLayernorm::operator()() {
emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out});
}
void PrelnEmbEltwiseLayernorm::operator()() {
// Create nodes for fused_preln_embedding_eltwise_layernorm.
auto* preln_emb_elt_layernorm_op =
pattern->NewNode(preln_emb_elt_layernorm_op_repr())
->assert_is_op("fused_preln_embedding_eltwise_layernorm");
auto* preln_emb_elt_layernorm_out_0 =
pattern->NewNode(preln_emb_elt_layernorm_out_0_repr())
->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
"Out_0");
auto* preln_emb_elt_layernorm_out_1 =
pattern->NewNode(preln_emb_elt_layernorm_out_1_repr())
->assert_is_op_output("fused_preln_embedding_eltwise_layernorm",
"Out_1");
// Add links for fused_preln_embedding_eltwise_layernorm op.
preln_emb_elt_layernorm_op->LinksTo(
{preln_emb_elt_layernorm_out_0, preln_emb_elt_layernorm_out_1});
}
void SkipLayernorm::operator()() {
// Create nodes for skip_layernorm.
auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())
......@@ -51,6 +70,30 @@ void SkipLayernorm::operator()() {
.LinksTo({skip_layernorm_out});
}
void PrelnSkipLayernorm::operator()() {
// Create nodes for preln_skip_layernorm.
auto* preln_skip_layernorm_x =
pattern->NewNode(preln_skip_layernorm_x_repr())
->assert_is_op_input("preln_skip_layernorm", "X");
auto* preln_skip_layernorm_y =
pattern->NewNode(preln_skip_layernorm_y_repr())
->assert_is_op_input("preln_skip_layernorm", "Y");
auto* preln_skip_layernorm_op =
pattern->NewNode(preln_skip_layernorm_op_repr())
->assert_is_op("preln_skip_layernorm");
auto* preln_skip_layernorm_out_0 =
pattern->NewNode(preln_skip_layernorm_out_0_repr())
->assert_is_op_output("preln_skip_layernorm", "Out_0");
auto* preln_skip_layernorm_out_1 =
pattern->NewNode(preln_skip_layernorm_out_1_repr())
->assert_is_op_output("preln_skip_layernorm", "Out_1");
// Add links for preln_skip_layernorm op.
preln_skip_layernorm_op
->LinksFrom({preln_skip_layernorm_x, preln_skip_layernorm_y})
.LinksTo({preln_skip_layernorm_out_0, preln_skip_layernorm_out_1});
}
void MultiheadMatmul::operator()() {
// Create nodes for multihead_matmul.
auto* multihead_matmul_input =
......@@ -96,10 +139,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "" &&
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
(graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
graph->Has(framework::ir::kMultiheadMatmulPass)) {
VLOG(3) << "start varseqlen remove_padding_recover_padding_pass";
} else {
VLOG(3) << "remove_padding_recover_padding_pass check failed";
return;
}
......@@ -131,9 +176,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
remove_padding.SetOutput("Out", {remove_padding_out_name});
// set out_threshold for int8
if (op_node->Op()->HasAttr("out_threshold")) {
if (op_node->Op()->HasAttr("Input_scale")) {
remove_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_threshold"));
op_node->Op()->GetAttr("Input_scale"));
} else {
VLOG(3) << "remove_padding_op has not out_threshold, because next op has "
"not Input_scale.";
}
auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);
......@@ -194,6 +242,15 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
if (op_node->Op()->HasAttr("out_threshold")) {
recover_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_threshold"));
} else if (op_node->Op()->HasAttr("out_0_threshold")) {
recover_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_0_threshold"));
} else if (op_node->Op()->HasAttr("out_1_threshold")) {
recover_padding.SetAttr("out_threshold",
op_node->Op()->GetAttr("out_1_threshold"));
} else {
VLOG(3) << "recover_padding_op has not out_threshold, because previous "
"op has not out_*_threshold.";
}
auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);
......@@ -241,9 +298,11 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"fused_embedding_eltwise_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, emb_elt_layernorm_op,
GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op,
emb_elt_layernorm_op,
fused_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, emb_elt_layernorm_out,
GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out,
emb_elt_layernorm_out,
fused_embedding_eltwise_layernorm);
insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out);
......@@ -263,12 +322,12 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"multihead_matmul";
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input,
multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_op, multihead_matmul_op,
multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out,
multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(
multihead_matmul_input, multihead_matmul_input, multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(
multihead_matmul_op, multihead_matmul_op, multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(
multihead_matmul_out, multihead_matmul_out, multihead_matmul);
multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape();
......@@ -289,14 +348,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"skip_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
skip_layernorm_x, skip_layernorm_x, skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
skip_layernorm_y, skip_layernorm_y, skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
skip_layernorm_op, skip_layernorm_op, skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
skip_layernorm_out, skip_layernorm_out, skip_layernorm);
std::vector<int64_t> skip_layernorm_x_shape =
skip_layernorm_x->Var()->GetShape();
......@@ -417,6 +476,86 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
};
gpd4(graph, handler4);
GraphPatternDetector gpd5;
patterns::PrelnEmbEltwiseLayernorm fused_preln_embedding_eltwise_layernorm(
gpd5.mutable_pattern(), "remove_padding_recover_padding_pass");
fused_preln_embedding_eltwise_layernorm();
auto handler5 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"fused_preln_embedding_eltwise_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_op,
preln_emb_elt_layernorm_op,
fused_preln_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_0,
preln_emb_elt_layernorm_out_0,
fused_preln_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(preln_emb_elt_layernorm_out_1,
preln_emb_elt_layernorm_out_1,
fused_preln_embedding_eltwise_layernorm);
insert_recover_padding_op(preln_emb_elt_layernorm_op,
preln_emb_elt_layernorm_out_0);
insert_recover_padding_op(preln_emb_elt_layernorm_op,
preln_emb_elt_layernorm_out_1);
found_subgraph_count++;
};
gpd5(graph, handler5);
GraphPatternDetector gpd6;
patterns::PrelnSkipLayernorm preln_skip_layernorm(
gpd6.mutable_pattern(), "remove_padding_recover_padding_pass");
preln_skip_layernorm();
auto handler6 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"preln_skip_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(
preln_skip_layernorm_x, preln_skip_layernorm_x, preln_skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
preln_skip_layernorm_y, preln_skip_layernorm_y, preln_skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(
preln_skip_layernorm_op, preln_skip_layernorm_op, preln_skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_0,
preln_skip_layernorm_out_0,
preln_skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(preln_skip_layernorm_out_1,
preln_skip_layernorm_out_1,
preln_skip_layernorm);
std::vector<int64_t> skip_layernorm_x_shape =
preln_skip_layernorm_x->Var()->GetShape();
if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
check_flag = false;
}
}
if (!check_flag) {
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
insert_remove_padding_op(preln_skip_layernorm_x, preln_skip_layernorm_op);
insert_remove_padding_op(preln_skip_layernorm_y, preln_skip_layernorm_op);
insert_recover_padding_op(preln_skip_layernorm_op,
preln_skip_layernorm_out_0);
insert_recover_padding_op(preln_skip_layernorm_op,
preln_skip_layernorm_out_1);
found_subgraph_count++;
};
gpd6(graph, handler6);
AddStatis(found_subgraph_count);
}
......
......@@ -41,6 +41,16 @@ struct EmbEltwiseLayernorm : public PatternBase {
PATTERN_DECL_NODE(emb_elt_layernorm_out);
};
struct PrelnEmbEltwiseLayernorm : public PatternBase {
PrelnEmbEltwiseLayernorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_emb_elt_layernorm") {}
void operator()();
PATTERN_DECL_NODE(preln_emb_elt_layernorm_op);
PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_0);
PATTERN_DECL_NODE(preln_emb_elt_layernorm_out_1);
};
struct SkipLayernorm : public PatternBase {
SkipLayernorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "skip_layernorm") {}
......@@ -53,6 +63,19 @@ struct SkipLayernorm : public PatternBase {
PATTERN_DECL_NODE(skip_layernorm_out);
};
struct PrelnSkipLayernorm : public PatternBase {
PrelnSkipLayernorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_skip_layernorm") {}
void operator()();
PATTERN_DECL_NODE(preln_skip_layernorm_x);
PATTERN_DECL_NODE(preln_skip_layernorm_y);
PATTERN_DECL_NODE(preln_skip_layernorm_op);
PATTERN_DECL_NODE(preln_skip_layernorm_out_0);
PATTERN_DECL_NODE(preln_skip_layernorm_out_1);
};
struct MultiheadMatmul : public PatternBase {
MultiheadMatmul(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul") {}
......
......@@ -51,11 +51,20 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0, Node* mul0, Node* mul1, Node* mul2,
Node* mul0_out, Node* mul1_out, Node* mul2_out,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b,
Node* eltadd_qk_b, Node* reshape2,
Node* reshape2_qkv_out, Node* scale,
auto fuse_creater = [&](Node* input0,
Node* mul0,
Node* mul1,
Node* mul2,
Node* mul0_out,
Node* mul1_out,
Node* mul2_out,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* eltadd_qk_b,
Node* reshape2,
Node* reshape2_qkv_out,
Node* scale,
Node* scale_out) {
auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
// auto scale_bias = BOOST_GET_CONST(float, scale->Op()->GetAttr("bias"));
......@@ -123,11 +132,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
......@@ -135,21 +144,21 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
......@@ -172,24 +181,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
reshape2_qkv_out, scale, scale_out);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);
fuse_creater(input0,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
eltadd0_b,
eltadd1_b,
eltadd2_b,
eltadd_qk_b,
reshape2_0,
reshape2_qkv_out,
scale,
scale_out);
std::unordered_set<const Node*> marked_nodes(
{eltadd0,
......@@ -777,14 +798,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0, Node* mul0, Node* mul1, Node* mul2,
Node* mul0_out, Node* mul1_out, Node* mul2_out,
Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b,
Node* eltadd_qk_b, Node* reshape2,
Node* reshape2_qkv_out, Node* scale, Node* scale_out,
Node* softmax_qk, Node* eltadd0, Node* eltadd1,
Node* eltadd2, Node* matmul_qk, Node* reshape2_qkv) {
auto fuse_creater = [&](Node* input0,
Node* mul0,
Node* mul1,
Node* mul2,
Node* mul0_out,
Node* mul1_out,
Node* mul2_out,
Node* mul0_w,
Node* mul1_w,
Node* mul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* eltadd_qk_b,
Node* reshape2,
Node* reshape2_qkv_out,
Node* scale,
Node* scale_out,
Node* softmax_qk,
Node* eltadd0,
Node* eltadd1,
Node* eltadd2,
Node* matmul_qk,
Node* reshape2_qkv) {
auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
......@@ -842,7 +879,8 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data, tmp_combined_w_data,
memcpy(new_combined_w_data,
tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());
scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
......@@ -854,15 +892,17 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data,
sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
memcpy(
tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size,
bv_data,
sizeof(float) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data, tmp_combined_bias_data,
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
......@@ -944,11 +984,11 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
......@@ -956,21 +996,21 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
......@@ -993,20 +1033,20 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, this pass will not take
......@@ -1018,10 +1058,30 @@ int TrtMultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
if (is_fc_params_shared) {
return;
}
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, scale, scale_out, softmax_qk,
eltadd0, eltadd1, eltadd2, matmul_qk, reshape2_qkv);
fuse_creater(input0,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul0_w,
mul1_w,
mul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
eltadd_qk_b,
reshape2_0,
reshape2_qkv_out,
scale,
scale_out,
softmax_qk,
eltadd0,
eltadd1,
eltadd2,
matmul_qk,
reshape2_qkv);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
......@@ -1083,19 +1143,28 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
int fusion_count = BuildFusionV2(graph, name_scope_, scope);
if (fusion_count > 0) {
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "") {
if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v2";
if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) {
if (with_interleaved) {
VLOG(3) << "start interleaved_format "
"varseqlen_trt_multihead_matmul_fuse_pass_v2";
} else {
PADDLE_THROW(platform::errors::Fatal(
"Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v2";
}
} else {
PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass or "
"preln_embedding_eltwise_layernorm_fuse_"
"pass. please use no_varseqlen"));
}
} else if (!use_varseqlen && pos_id == "" && mask_id == "") {
VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v2";
VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
} else {
PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need config: "
......@@ -1251,12 +1320,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](Node* input0, Node* mul0, Node* mul1, Node* mul2,
Node* mul0_out, Node* mul1_out, Node* mul2_out,
Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b,
Node* eltadd_qk_b, Node* reshape2,
Node* reshape2_qkv_out, Node* matmul_qk) {
auto fuse_creater = [&](Node* input0,
Node* mul0,
Node* mul1,
Node* mul2,
Node* mul0_out,
Node* mul1_out,
Node* mul2_out,
Node* mul0_w,
Node* mul1_w,
Node* mul2_w,
Node* eltadd0_b,
Node* eltadd1_b,
Node* eltadd2_b,
Node* eltadd_qk_b,
Node* reshape2,
Node* reshape2_qkv_out,
Node* matmul_qk) {
auto scale_attr = BOOST_GET_CONST(float, matmul_qk->Op()->GetAttr("alpha"));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
......@@ -1314,7 +1394,8 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data, tmp_combined_w_data,
memcpy(new_combined_w_data,
tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());
scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
......@@ -1326,15 +1407,17 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data,
sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
memcpy(
tmp_combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size,
bv_data,
sizeof(float) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data, tmp_combined_bias_data,
memcpy(new_combined_bias_data,
tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
......@@ -1375,31 +1458,31 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_0_out, reshape2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_0_out, transpose2_0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_1_out, reshape2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_1_out, transpose2_1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_2_out, reshape2_2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_2_out, transpose2_2_out, multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
......@@ -1422,20 +1505,20 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv_out, matmul_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_qkv_out, reshape2_qkv_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv, transpose2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_qkv_out, transpose2_qkv_out, multihead_pattern);
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, this pass will not take
......@@ -1447,9 +1530,23 @@ int TrtMultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
if (is_fc_params_shared) {
return;
}
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, matmul_qk);
fuse_creater(input0,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul0_w,
mul1_w,
mul2_w,
eltadd0_b,
eltadd1_b,
eltadd2_b,
eltadd_qk_b,
reshape2_0,
reshape2_qkv_out,
matmul_qk);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
......@@ -1510,19 +1607,28 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
int fusion_count = BuildFusionV3(graph, name_scope_, scope);
if (fusion_count > 0) {
bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "") {
if (graph->Has(framework::ir::kEmbEltwiseLayernormPass)) {
VLOG(3) << "start varseqlen trt_multihead_matmul_fuse_pass_v3";
if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) {
if (with_interleaved) {
VLOG(3) << "start interleaved_format "
"varseqlen_trt_multihead_matmul_fuse_pass_v3";
} else {
PADDLE_THROW(platform::errors::Fatal(
"Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
VLOG(3) << "start varseqlen_trt_multihead_matmul_fuse_pass_v3";
}
} else {
PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass or "
"preln_embedding_eltwise_layernorm_fuse_"
"pass. please use no_varseqlen"));
}
} else if (!use_varseqlen && pos_id == "" && mask_id == "") {
VLOG(3) << "start no_varseqlen trt_multihead_matmul_fuse_pass_v3";
VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
} else {
PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need config: "
......
......@@ -139,12 +139,12 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_variance, layer_norm_variance, fused_pattern);
std::unordered_set<const Node *> del_node_set;
......@@ -197,13 +197,15 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if (use_varseqlen && pos_id != "" && mask_id != "") {
if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
if ((graph->Has(framework::ir::kEmbEltwiseLayernormPass) ||
graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass)) &&
graph->Has(framework::ir::kMultiheadMatmulPass)) {
VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass";
} else {
PADDLE_THROW(platform::errors::Fatal(
"Use transformer'varseqlen need "
"embedding_eltwise_layernorm_fuse_pass. please use no_varseqlen"));
"trt_embedding_eltwise_layernorm_fuse_pass, "
"trt_multihead_matmul_fuse_pass. please use no_varseqlen"));
}
} else if (!use_varseqlen && pos_id == "" && mask_id == "") {
VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
......
......@@ -28,13 +28,21 @@ namespace tensorrt {
class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(7000)
VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
auto pos_id_name = engine_->tensorrt_transformer_posid();
auto mask_id_name = engine_->tensorrt_transformer_maskid();
bool flag_prelayernorm = engine_->with_interleaved() &&
engine_->use_varseqlen() && pos_id_name != "" &&
mask_id_name != "";
if (!flag_prelayernorm) {
PADDLE_THROW(platform::errors::Fatal(
"PrelnErnie: If you want to use oss, must be with interleaved"));
"PrelnErnie: If you want to use varseqlen, must be with interleaved, "
"set pos_id_name, set mask_id_name."));
}
framework::OpDesc op_desc(op, nullptr);
bool enable_int8 = op_desc.HasAttr("enable_int8");
......@@ -43,7 +51,6 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
platform::errors::Fatal("use with_interleaved must be int8."));
}
auto word_id_name = op_desc.Input("WordId").front();
auto pos_id_name = op_desc.Input("PosId").front();
engine_->Set("ernie_pos_name", new std::string(pos_id_name));
auto sent_id_name = op_desc.Input("SentId").front();
......@@ -51,6 +58,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
auto pos_emb_name = op_desc.Input("PosEmbedding").front();
auto sent_emb_name = op_desc.Input("SentEmbedding").front();
engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
std::vector<std::string> emb_names;
emb_names =
std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
......@@ -81,7 +92,8 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
input_embs.push_back(emb_data);
emb_sizes.push_back(emb_size);
PADDLE_ENFORCE_EQ(
emb_dims.size(), 2,
emb_dims.size(),
2,
platform::errors::InvalidArgument(
"The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims."));
}
......@@ -97,23 +109,31 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
int output_int8 = 1;
PADDLE_ENFORCE_EQ(
input_num, 3,
input_num,
3,
platform::errors::InvalidArgument(
"When using oss and var-len, embedding_eltwise_layernorm op"
"should have 3 inputs only, but got %d.",
input_num));
const std::vector<nvinfer1::PluginField> fields{
{"bert_embeddings_layernorm_beta", bias,
nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(bias_size)},
{"bert_embeddings_layernorm_gamma", scale,
nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(scale_size)},
{"bert_embeddings_word_embeddings", input_embs[0],
{"bert_embeddings_layernorm_beta",
bias,
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(bias_size)},
{"bert_embeddings_layernorm_gamma",
scale,
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(scale_size)},
{"bert_embeddings_word_embeddings",
input_embs[0],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[0])},
{"bert_embeddings_token_type_embeddings", input_embs[2],
{"bert_embeddings_token_type_embeddings",
input_embs[2],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[2])},
{"bert_embeddings_position_embeddings", input_embs[1],
{"bert_embeddings_position_embeddings",
input_embs[1],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[1])},
{"output_fp16", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},
......@@ -136,8 +156,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_inputs.emplace_back(
engine_->GetITensor(pos_id_name)); // cu_seqlens,
// eval_placeholder_2
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
auto* shuffle_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
nvinfer1::Dims shape_dim;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册