From 2bc91cc5c44558be0bb000f8cdf8301ed1a6de5e Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Tue, 15 Feb 2022 17:10:00 +0800 Subject: [PATCH] [Paddle-Inference] support preln_ernie: add preln_embedding_eltwise_layernorm_fuse_pass, preln_skip_layernorm_fuse_pass (#39508) * support preln_ernie * support preln_ernie --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + paddle/fluid/framework/ir/pass.h | 2 + ...n_embedding_eltwise_layernorm_fuse_pass.cc | 450 ++++++++++++++++++ ...ln_embedding_eltwise_layernorm_fuse_pass.h | 166 +++++++ .../ir/preln_skip_layernorm_fuse_pass.cc | 210 ++++++++ .../ir/preln_skip_layernorm_fuse_pass.h | 86 ++++ .../ir_passes/tensorrt_subgraph_pass.cc | 6 +- .../inference/api/paddle_pass_builder.cc | 34 +- 8 files changed, 938 insertions(+), 18 deletions(-) create mode 100644 paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 829f43effb..0e1e572a51 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -103,6 +103,8 @@ target_link_libraries(generate_pass pass_desc_proto) if(WITH_TENSORRT) pass_library(trt_map_matmul_to_mul_pass inference) + pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference) + pass_library(preln_skip_layernorm_fuse_pass inference) endif() if(WITH_GPU OR WITH_ROCM) diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 016d0fd4a6..acfe8d53ce 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -47,6 +47,8 @@ constexpr char kPassRecorder[] = "pass_recorder"; constexpr char kEmbEltwiseLayernormPass[] = "embedding_eltwise_layernorm_fuse_pass_flag"; constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag"; +constexpr char kPrelnEmbEltwiseLayernormPass[] = + "preln_embedding_eltwise_layernorm_fuse_pass_flag"; class Pass { public: diff --git a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc new file mode 100644 index 0000000000..ca42a61341 --- /dev/null +++ b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc @@ -0,0 +1,450 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +class Node; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name, + const std::string& arg, + bool is_persist = false) { + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + PDNode* node = + pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg); + if (is_persist) return node->assert_is_persistable_var(); + return node; +} +static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name, + const std::string& arg) { + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + PDNode* node = pattern->NewNode(name) + ->assert_is_only_output_of_ops(embedding_ops) + ->assert_is_op_input("elementwise_add", arg) + ->AsIntermediate(); + return node; +} +void PrelnEmbedding2Eltwise1Pattern::operator()() { + auto* lookup_table1_x = + create_emb_vars(pattern, lookup_table1_x_repr(), "Ids"); + auto* lookup_table2_x = + create_emb_vars(pattern, lookup_table2_x_repr(), "Ids"); + auto* lookup_table1_w = + create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); + auto* lookup_table2_w = + create_emb_vars(pattern, lookup_table2_w_repr(), "W", true); + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + auto* lookup_table1 = + pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); + auto* lookup_table2 = + pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops); + auto* lookup_table1_out = + create_emb_out_vars(pattern, lookup_table1_out_repr(), "X"); + auto* lookup_table2_out = + create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y"); + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) + .LinksTo({lookup_table1_out}); + lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w}) + .LinksTo({lookup_table2_out}); + eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out}) + .LinksTo({eltwise_add_out}); +} +void PrelnEmbedding1Eltwise1Pattern::operator()() { + auto* lookup_table1_x = + create_emb_vars(pattern, lookup_table1_x_repr(), "Ids"); + auto* lookup_table1_w = + create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + auto* lookup_table1 = + pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); + auto* lookup_table1_out = + create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y"); + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_in = pattern->NewNode(eltwise_add_in_repr()) + ->assert_is_op_input("elementwise_add", "X") + ->assert_is_op_output("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) + .LinksTo({lookup_table1_out}); + eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in}) + .LinksTo({eltwise_add_out}); +} +void PrelnSkipLayerNorm::operator()() { + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + ->assert_is_op_output("elementwise_add") + ->assert_is_op_input("layer_norm", "X") + ->assert_is_op_input("elementwise_add", "Y"); + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr()) + ->assert_is_op_output("layer_norm", "Y") + ->AsOutput(); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto* layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + eltwise_add->LinksTo({eltwise_add_out}); + layer_norm + ->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var}); +} + +} // namespace patterns + +int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( + Graph* graph, const std::string& name_scope + /*const Scope* scope*/) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + std::vector>> start_pattern_in_nodes; + std::vector start_pattern_out_node; + std::vector> start_pattern_remove_nodes; + + // Create pattern. + patterns::PrelnEmbedding2Eltwise1Pattern start_pattern(pattern, + name_scope + "/start"); + start_pattern(); + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, + start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out, + start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern); + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) + << "Pass(PrelnEmbedding2Eltwise1Pattern) in op compat failed."; + return; + } + std::vector> ins; + ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w)); + ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w)); + start_pattern_in_nodes.push_back(ins); + start_pattern_out_node.push_back(eltwise_add_out); + + std::unordered_set rm_nodes; + rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out, + lookup_table2_out, eltwise_add, eltwise_add_out}); + start_pattern_remove_nodes.push_back(rm_nodes); + }; + gpd(graph, handler); + + std::vector> inner_pattern_ins; + std::vector inner_pattern_tmp_in; + std::vector inner_pattern_out; + std::vector> inner_pattern_remove_nodes; + + GraphPatternDetector gpd2; + auto* pattern2 = gpd2.mutable_pattern(); + patterns::PrelnEmbedding1Eltwise1Pattern second_pattern( + pattern2, name_scope + "/second"); + second_pattern(); + auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, + second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern); + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) + << "Pass(PrelnEmbedding1Eltwise1Pattern) in op compat failed."; + return; + } + auto in = std::make_pair(lookup_table1_x, lookup_table1_w); + inner_pattern_ins.push_back(in); + inner_pattern_tmp_in.push_back(eltwise_add_in); + inner_pattern_out.push_back(eltwise_add_out); + + std::unordered_set rm_nodes; + rm_nodes.insert({lookup_table1, lookup_table1_out, eltwise_add}); + inner_pattern_remove_nodes.push_back(rm_nodes); + }; + gpd2(graph, handler2); + + std::vector end_pattern_elt_out; + std::vector end_pattern_scales; + std::vector end_pattern_biases; + std::vector end_pattern_out; + std::vector end_patter_layernorms; + std::vector end_patter_elementwise; + std::vector> end_pattern_remove_nodes; + GraphPatternDetector gpd3; + auto* pattern3 = gpd3.mutable_pattern(); + patterns::PrelnSkipLayerNorm skip_layernorm_pattern(pattern3, + name_scope + "/third"); + skip_layernorm_pattern(); + auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, + skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, + skip_layernorm_pattern); + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed."; + return; + } + end_pattern_elt_out.push_back(eltwise_add_out); + std::unordered_set rm_nodes; + rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance}); + end_pattern_remove_nodes.push_back(rm_nodes); + end_pattern_biases.push_back(layer_norm_bias); + end_pattern_scales.push_back(layer_norm_scale); + end_pattern_out.push_back(layer_norm_out); + end_patter_layernorms.push_back(layer_norm); + end_patter_elementwise.push_back(eltwise_add); + }; + gpd3(graph, handler3); + + if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) { + return 0; + } + // only reserve the subgraphs that in connected domains. + int fusion_count = 0; + // fusion_id for (i, k, js) + std::vector>>> + fusion_ids; + for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) { + Node* tmp = start_pattern_out_node[i]; + Node* old_tmp = nullptr; + // get correct inner pattern node order. + std::vector js; + while (tmp != old_tmp) { + old_tmp = tmp; + for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) { + if (inner_pattern_tmp_in[j] == tmp) { + tmp = inner_pattern_out[j]; + js.push_back(j); + break; + } + } + } + + for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) { + if (tmp == end_pattern_elt_out[k]) { + fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js))); + break; + } + } + } + + for (size_t num = 0; num < fusion_ids.size(); ++num) { + int i = fusion_ids[num].first; + int k = fusion_ids[num].second.first; + std::vector js = fusion_ids[num].second.second; + + std::vector ids; + std::vector embs; + for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) { + ids.push_back(start_pattern_in_nodes[i][iter].first->Name()); + embs.push_back(start_pattern_in_nodes[i][iter].second->Name()); + } + for (size_t iter = 0; iter < js.size(); ++iter) { + ids.push_back(inner_pattern_ins[js[iter]].first->Name()); + embs.push_back(inner_pattern_ins[js[iter]].second->Name()); + } + + OpDesc new_op_desc; + new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm"); + new_op_desc.SetInput("Ids", ids); + new_op_desc.SetInput("Embs", embs); + new_op_desc.SetInput("WordId", {ids[0]}); + new_op_desc.SetInput("PosId", {ids[1]}); + if (ids.size() > 2) { + new_op_desc.SetInput("SentId", {ids[2]}); + } + + new_op_desc.SetInput("WordEmbedding", {embs[0]}); + new_op_desc.SetInput("PosEmbedding", {embs[1]}); + if (embs.size() > 2) { + new_op_desc.SetInput("SentEmbedding", {embs[2]}); + } + + new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()}); + new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()}); + new_op_desc.SetOutput("Out_0", {end_pattern_out[k]->Name()}); + new_op_desc.SetOutput("Out_1", {inner_pattern_out[k]->Name()}); + new_op_desc.SetAttr("epsilon", + end_patter_layernorms[k]->Op()->GetAttr("epsilon")); + + if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold") && + end_patter_elementwise[k]->Op()->HasAttr("out_threshold")) { + new_op_desc.SetAttr("enable_int8", true); + new_op_desc.SetAttr( + "out_0_threshold", + end_patter_layernorms[k]->Op()->GetAttr("out_threshold")); + new_op_desc.SetAttr( + "out_1_threshold", + end_patter_elementwise[k]->Op()->GetAttr("out_threshold")); + } + + auto* preln_embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc); + + for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) { + IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first, + preln_embedding_eltwise_layernorm); + IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second, + preln_embedding_eltwise_layernorm); + } + for (size_t iter = 0; iter < js.size(); ++iter) { + IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first, + preln_embedding_eltwise_layernorm); + IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second, + preln_embedding_eltwise_layernorm); + } + IR_NODE_LINK_TO(end_pattern_biases[k], preln_embedding_eltwise_layernorm); + IR_NODE_LINK_TO(end_pattern_scales[k], preln_embedding_eltwise_layernorm); + IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, end_pattern_out[k]); + IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, inner_pattern_out[k]); + + // Remove unneeded nodes. + std::unordered_set marked_nodes; + marked_nodes.insert(start_pattern_remove_nodes[i].begin(), + start_pattern_remove_nodes[i].end()); + marked_nodes.insert(end_pattern_remove_nodes[k].begin(), + end_pattern_remove_nodes[k].end()); + for (size_t iter = 0; iter < js.size(); ++iter) { + marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(), + inner_pattern_remove_nodes[js[iter]].end()); + } + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + } + + return fusion_count; +} + +PrelnEmbeddingEltwiseLayerNormFusePass:: + PrelnEmbeddingEltwiseLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); +} + +void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + int fusion_count = + PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_); + if (fusion_count > 0) { + graph->Set(kPrelnEmbEltwiseLayernormPass, new bool(true)); + } + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(preln_embedding_eltwise_layernorm_fuse_pass, + paddle::framework::ir::PrelnEmbeddingEltwiseLayerNormFusePass); +REGISTER_PASS_CAPABILITY(preln_embedding_eltwise_layernorm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("lookup_table", 1) + .LE("lookup_table_v2", 1) + .LE("elementweise_add", 1)); diff --git a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h new file mode 100644 index 0000000000..1ccc6c85d4 --- /dev/null +++ b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h @@ -0,0 +1,166 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +class Graph; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +// detect start pattern. +// +// in_var emb in_var emb +// | | | | +// lookup_table lookup_table +// | | +// lkt_var lkt_var +// \ / +// elementwise_add +// | +// elt_out_var +// +struct PrelnEmbedding2Eltwise1Pattern : public PatternBase { + PrelnEmbedding2Eltwise1Pattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {} + + void operator()(); + + PATTERN_DECL_NODE(lookup_table1_x); + PATTERN_DECL_NODE(lookup_table2_x); + PATTERN_DECL_NODE(lookup_table1_w); + PATTERN_DECL_NODE(lookup_table2_w); + PATTERN_DECL_NODE(lookup_table1); + PATTERN_DECL_NODE(lookup_table2); + PATTERN_DECL_NODE(lookup_table1_out); + PATTERN_DECL_NODE(lookup_table2_out); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); +}; + +// detect repeats inner pattern +// +// elt_out_var in_var emb +// \ | | +// \ lookup_table +// \ | +// \ lkt_var +// \ / +// elementwise_add +// | | +// elementwise_add elt_out_var +// +struct PrelnEmbedding1Eltwise1Pattern : public PatternBase { + PrelnEmbedding1Eltwise1Pattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {} + void operator()(); + PATTERN_DECL_NODE(lookup_table1_x); + PATTERN_DECL_NODE(lookup_table1_w); + PATTERN_DECL_NODE(lookup_table1); + PATTERN_DECL_NODE(lookup_table1_out); + PATTERN_DECL_NODE(eltwise_add_in); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); +}; + +// detect end pattern +// +// elementwise_add +// | | +// | elt_out_var +// | scale | bias +// | \ | / +// elementwise_add layer_norm +// +struct PrelnSkipLayerNorm : public PatternBase { + PrelnSkipLayerNorm(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "Prelnskip_layernorm") {} + void operator()(); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_out); + // Delete the mean and var nodes in the graph. + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); +}; +} // namespace patterns + +// The PrelnEmbeddingEltwiseLayerNormFusePass detect the following pattern: +// +// inputs operator output +// -------------------------------------------------------------------- +// (word, weights_0) lookup_table -> word_emb +// (pos, weights_1) lookup_table -> pos_emb +// (sent, weights_2) lookup_table -> sent_emb +// (word_emb, pos_emb) elementweise_add -> elementwise_out_0 +// (elemtwise_out_0, sent_emb) elementweise_add -> elementwise_out_1 +// (elementwise_out_1, scale, bias) layer_norm -> layer_norm_out +// +// and then convert the corresponding subgraph to: +// +// (word, pos, sent, weights_0, weights_1, weights_2, +// scale, baias) Prelnembedding_eltwise_layernorm -> layer_norm_out + +// elementwise_add_out +// +// +// in_var emb_var in_var emb_var in_var emb_var in_var emb_var +// | | | | | | | | +// lookup_table lookup_table lookup_table ... lookup_table +// | | | | +// lkt_var lkt_var lkt_var lkt_var +// \ / | ... | +// elementwise_add | | +// \ / | +// elementwise_add | +// | | +// elt_var / +// \ / +// elementwise_add +// | | +// elementwise_add layer_norm + +class PrelnEmbeddingEltwiseLayerNormFusePass : public FusePassBase { + public: + PrelnEmbeddingEltwiseLayerNormFusePass(); + virtual ~PrelnEmbeddingEltwiseLayerNormFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const; + int BuildFusion(Graph* graph, const std::string& name_scope + /*const Scope* scope*/) const; + const std::string name_scope_{"preln_embedding_eltwise_layernorm_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc new file mode 100644 index 0000000000..1b7b82cbca --- /dev/null +++ b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc @@ -0,0 +1,210 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +class Node; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct PrelnSkipLayerNorm : public PatternBase { + PrelnSkipLayerNorm(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "preln_skip_layernorm") {} + + void operator()(PDNode *x, PDNode *y); + + // declare operator node's name + PATTERN_DECL_NODE(fused_skipe_layernorm); + PATTERN_DECL_NODE(elementwise); + PATTERN_DECL_NODE(layer_norm); + // declare variable node's name + PATTERN_DECL_NODE( + elementwise_out); // (elementwise_input_x,elementwise_input_y) -> + // elementwise_out + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(layer_norm_mean); + PATTERN_DECL_NODE(layer_norm_variance); +}; + +void PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) { + // Create nodes for elementwise add op. + x->assert_is_op_input("elementwise_add", "X"); + y->assert_is_op_input("elementwise_add", "Y"); + auto *elementwise = + pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add"); + auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr()) + ->assert_is_op_output("elementwise_add") + ->assert_is_op_input("layer_norm", "X") + ->assert_is_op_input("elementwise_add", "Y"); + + // Add links for elementwise_add op. + elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var}); + + // Create nodes for layer_norm op. + auto *layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + + auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Y"); + auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto *layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + + // Add links for layer_norm op. + layer_norm + ->LinksFrom( + {elementwise_out_var, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); +} + +} // namespace patterns + +void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + FusePassBase::Init("preln_skip_layernorm_fuse", graph); + int found_subgraph_count = 0; + + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode("preln_skip_layernorm_fuse/x") + ->AsInput() + ->assert_is_op_input("elementwise_add", "X") + ->assert_var_not_persistable(); + auto *y = gpd.mutable_pattern() + ->NewNode("preln_skip_layernorm_fuse/y") + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y") + ->assert_var_not_persistable(); + patterns::PrelnSkipLayerNorm fused_pattern(gpd.mutable_pattern(), + "preln_skip_layernorm_fuse"); + fused_pattern(x, y); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) { + LOG(WARNING) << "The subgraph is empty."; + return; + } + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "preln_skip_layernorm pass in op compat failed."; + return; + } + + VLOG(4) << "handle PrelnSkipLayerNorm fuse"; + GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, + fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, + fused_pattern); + + std::unordered_set del_node_set; + + // Create an PrelnSkipLayerNorm op node + OpDesc new_desc; + new_desc.SetType("preln_skip_layernorm"); + + // inputs + new_desc.SetInput("X", {subgraph.at(x)->Name()}); + new_desc.SetInput("Y", {subgraph.at(y)->Name()}); + new_desc.SetInput("Scale", {layer_norm_scale->Name()}); + new_desc.SetInput("Bias", {layer_norm_bias->Name()}); + + if (elementwise->Op()->HasAttr("out_threshold") && + layer_norm->Op()->HasAttr("out_threshold")) { + new_desc.SetAttr("enable_int8", true); + new_desc.SetAttr("out_0_threshold", + layer_norm->Op()->GetAttr("out_threshold")); + new_desc.SetAttr("out_1_threshold", + elementwise->Op()->GetAttr("out_threshold")); + } + + // outputs + new_desc.SetOutput("Out_0", {layer_norm_out->Name()}); + new_desc.SetOutput("Out_1", {elementwise_out->Name()}); + + // attrs + new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon")); + new_desc.SetAttr("begin_norm_axis", + layer_norm->Op()->GetAttr("begin_norm_axis")); + + auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. + + del_node_set.insert(elementwise); + del_node_set.insert(layer_norm); + del_node_set.insert(layer_norm_mean); + del_node_set.insert(layer_norm_variance); + GraphSafeRemoveNodes(graph, del_node_set); + + IR_NODE_LINK_TO(subgraph.at(x), fused_node); + IR_NODE_LINK_TO(subgraph.at(y), fused_node); + IR_NODE_LINK_TO(layer_norm_scale, fused_node); + IR_NODE_LINK_TO(layer_norm_bias, fused_node); + IR_NODE_LINK_TO(fused_node, layer_norm_out); + IR_NODE_LINK_TO(fused_node, elementwise_out); + + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(preln_skip_layernorm_fuse_pass, + paddle::framework::ir::PrelnSkipLayerNormFusePass); +REGISTER_PASS_CAPABILITY(preln_skip_layernorm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("layer_norm", 0)); diff --git a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.h new file mode 100644 index 0000000000..52447bfd8d --- /dev/null +++ b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +// | | | | +// other_op1 other_op2 other_op1 other_op2 +// | | fuse \ / +// |------elementwise_add -> skip_layernorm +// | | | | +// other_op4 layer_norm other_op4 other_op3 +// | +// other_op3 +class Graph; + +class PrelnSkipLayerNormFusePass : public FusePassBase { + public: + PrelnSkipLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, -1}) + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); + } + + virtual ~PrelnSkipLayerNormFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 904baebcb0..e4fc52b6fa 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -379,8 +379,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp( trt_engine->SetUseInspector(Get("use_inspector")); trt_engine->SetWithErnie( - graph->Has(framework::ir::kEmbEltwiseLayernormPass) && - graph->Has(framework::ir::kMultiheadMatmulPass)); + (graph->Has(framework::ir::kEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass)) || + (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass))); if (use_static_engine) { trt_engine_serialized_data = GetTrtEngineSerializedData( diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 66b27b2903..313e1f2fae 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -82,22 +82,24 @@ const std::vector kTRTSubgraphPasses({ "quant_conv2d_dequant_fuse_pass", // "delete_quant_dequant_op_pass", // "delete_quant_dequant_filter_op_pass", // - // "fc_fuse_pass", // - "simplify_with_basic_ops_pass", // - "embedding_eltwise_layernorm_fuse_pass", // - "multihead_matmul_fuse_pass_v2", // - "multihead_matmul_fuse_pass_v3", // - "skip_layernorm_fuse_pass", // - "conv_bn_fuse_pass", // - "unsqueeze2_eltwise_fuse_pass", // - "trt_squeeze2_matmul_fuse_pass", // - "trt_reshape2_matmul_fuse_pass", // - "trt_flatten2_matmul_fuse_pass", // - "trt_map_matmul_v2_to_mul_pass", // - "trt_map_matmul_v2_to_matmul_pass", // - "trt_map_matmul_to_mul_pass", // - "fc_fuse_pass", // - "conv_elementwise_add_fuse_pass", // + // "fc_fuse_pass", // + "simplify_with_basic_ops_pass", // + "embedding_eltwise_layernorm_fuse_pass", // + "preln_embedding_eltwise_layernorm_fuse_pass", // + "multihead_matmul_fuse_pass_v2", // + "multihead_matmul_fuse_pass_v3", // + "skip_layernorm_fuse_pass", // + "preln_skip_layernorm_fuse_pass", // + "conv_bn_fuse_pass", // + "unsqueeze2_eltwise_fuse_pass", // + "trt_squeeze2_matmul_fuse_pass", // + "trt_reshape2_matmul_fuse_pass", // + "trt_flatten2_matmul_fuse_pass", // + "trt_map_matmul_v2_to_mul_pass", // + "trt_map_matmul_v2_to_matmul_pass", // + "trt_map_matmul_to_mul_pass", // + "fc_fuse_pass", // + "conv_elementwise_add_fuse_pass", // "add_support_int8_pass", "tensorrt_subgraph_pass", // "conv_bn_fuse_pass", // -- GitLab