diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc index 48f79e63b4f0ea51df27695943690c1c36727e93..0f6421134c21655b9ffb4313d3459541d59a659e 100644 --- a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc @@ -136,8 +136,12 @@ void SkipLayerNorm::operator()() { ->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var}) .LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var}); } -static int BuildFusion(Graph* graph, const std::string& name_scope - /*const Scope* scope*/) { + +} // namespace patterns + /* BuildFusion moves out of namespace patterns and becomes a const member of the pass; presumably so its match handlers can call IsCompat() - confirm against the base class. */ +int EmbeddingEltwiseLayerNormFusePass::BuildFusion( + Graph* graph, const std::string& name_scope + /*const Scope* scope*/) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); @@ -146,7 +150,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope std::vector> start_pattern_remove_nodes; // Create pattern. 
- Embedding2Eltwise1Pattern start_pattern(pattern, name_scope + "/start"); + patterns::Embedding2Eltwise1Pattern start_pattern(pattern, + name_scope + "/start"); start_pattern(); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -162,6 +167,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope start_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern); /* New guard: when IsCompat(subgraph, graph) is false, log a warning and return before this match's inputs are collected. */ + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(Embedding2Eltwise1Pattern) in op compat failed."; + return; + } std::vector> ins; ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w)); ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w)); @@ -182,7 +191,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope GraphPatternDetector gpd2; auto* pattern2 = gpd2.mutable_pattern(); - Embedding1Eltwise1Pattern second_pattern(pattern2, name_scope + "/second"); + patterns::Embedding1Eltwise1Pattern second_pattern(pattern2, + name_scope + "/second"); second_pattern(); auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -194,6 +204,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern); /* Same guard for the second pattern: return before inner_pattern_ins is appended to. */ + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(Embedding1Eltwise1Pattern) in op compat failed."; + return; + } auto in = std::make_pair(lookup_table1_x, lookup_table1_w); inner_pattern_ins.push_back(in); inner_pattern_tmp_in.push_back(eltwise_add_in); @@ -214,7 +228,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope std::vector> end_pattern_remove_nodes; GraphPatternDetector gpd3; auto* pattern3 = gpd3.mutable_pattern(); - SkipLayerNorm 
skip_layernorm_pattern(pattern3, name_scope + "/third"); + patterns::SkipLayerNorm skip_layernorm_pattern(pattern3, + name_scope + "/third"); skip_layernorm_pattern(); auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -232,6 +247,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope skip_layernorm_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, skip_layernorm_pattern); /* Same guard for the SkipLayerNorm pattern: return before end_pattern_elt_out is appended to. */ + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(SkipLayerNorm) in op compat failed."; + return; + } end_pattern_elt_out.push_back(eltwise_add_out); std::unordered_set rm_nodes; rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance}); @@ -349,11 +368,53 @@ static int BuildFusion(Graph* graph, const std::string& name_scope return fusion_count; } -} // namespace patterns /* New constructor: registers op-compat constraints via AddOpCompat for elementwise_add (inputs X/Y and output Out as tensors, attr axis restricted by IsIntIn({0, -1})) and for layer_norm (X/Scale/Bias inputs and Y/Mean/Variance outputs as tensors, epsilon bounded by IsNumGE(0.0f)/IsNumLE(0.001f), begin_norm_axis by IsNumGT(0)). */ +EmbeddingEltwiseLayerNormFusePass::EmbeddingEltwiseLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, -1}) + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); +} void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { FusePassBase::Init(name_scope_, graph); - int fusion_count = patterns::BuildFusion(graph, name_scope_); /* BuildFusion is now invoked as a member of this pass rather than as patterns::BuildFusion. */ + int fusion_count = + EmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_); if (fusion_count > 0) { graph->Set(kEmbEltwiseLayernormPass, new bool(true)); } diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h 
b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h index 25049d7468b152e72ad5f32fb38d9204f7219dff..fac9b49e886cb3ed55992cffe2c90c8fa5607dba 100644 --- a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h @@ -19,8 +19,6 @@ #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { @@ -150,11 +148,13 @@ struct SkipLayerNorm : public PatternBase { class EmbeddingEltwiseLayerNormFusePass : public FusePassBase { public: /* Constructor declared here; its definition in the .cc file registers the op-compat rules. */ + EmbeddingEltwiseLayerNormFusePass(); virtual ~EmbeddingEltwiseLayerNormFusePass() {} protected: void ApplyImpl(Graph* graph) const; - /* BuildFusion declared as a protected const member. */ + int BuildFusion(Graph* graph, const std::string& name_scope + /*const Scope* scope*/) const; const std::string name_scope_{"embedding_eltwise_layernorm_fuse"}; };