未验证 提交 ae74c404 编写于 作者: W Wangzheee 提交者: GitHub

[pass_enhance] embedding_eltwise_layernorm_fuse_pass (#33973)

上级 f2068eec
...@@ -136,8 +136,12 @@ void SkipLayerNorm::operator()() { ...@@ -136,8 +136,12 @@ void SkipLayerNorm::operator()() {
->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var}) ->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var}); .LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var});
} }
static int BuildFusion(Graph* graph, const std::string& name_scope
/*const Scope* scope*/) { } // namespace patterns
int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
...@@ -146,7 +150,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope ...@@ -146,7 +150,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
std::vector<std::unordered_set<Node*>> start_pattern_remove_nodes; std::vector<std::unordered_set<Node*>> start_pattern_remove_nodes;
// Create pattern. // Create pattern.
Embedding2Eltwise1Pattern start_pattern(pattern, name_scope + "/start"); patterns::Embedding2Eltwise1Pattern start_pattern(pattern,
name_scope + "/start");
start_pattern(); start_pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
...@@ -162,6 +167,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope ...@@ -162,6 +167,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
start_pattern); start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(Embedding2Eltwise1Pattern) in op compat failed.";
return;
}
std::vector<std::pair<Node*, Node*>> ins; std::vector<std::pair<Node*, Node*>> ins;
ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w)); ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w));
ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w)); ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w));
...@@ -182,7 +191,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope ...@@ -182,7 +191,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
GraphPatternDetector gpd2; GraphPatternDetector gpd2;
auto* pattern2 = gpd2.mutable_pattern(); auto* pattern2 = gpd2.mutable_pattern();
Embedding1Eltwise1Pattern second_pattern(pattern2, name_scope + "/second"); patterns::Embedding1Eltwise1Pattern second_pattern(pattern2,
name_scope + "/second");
second_pattern(); second_pattern();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
...@@ -194,6 +204,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope ...@@ -194,6 +204,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(Embedding1Eltwise1Pattern) in op compat failed.";
return;
}
auto in = std::make_pair(lookup_table1_x, lookup_table1_w); auto in = std::make_pair(lookup_table1_x, lookup_table1_w);
inner_pattern_ins.push_back(in); inner_pattern_ins.push_back(in);
inner_pattern_tmp_in.push_back(eltwise_add_in); inner_pattern_tmp_in.push_back(eltwise_add_in);
...@@ -214,7 +228,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope ...@@ -214,7 +228,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
std::vector<std::unordered_set<Node*>> end_pattern_remove_nodes; std::vector<std::unordered_set<Node*>> end_pattern_remove_nodes;
GraphPatternDetector gpd3; GraphPatternDetector gpd3;
auto* pattern3 = gpd3.mutable_pattern(); auto* pattern3 = gpd3.mutable_pattern();
SkipLayerNorm skip_layernorm_pattern(pattern3, name_scope + "/third"); patterns::SkipLayerNorm skip_layernorm_pattern(pattern3,
name_scope + "/third");
skip_layernorm_pattern(); skip_layernorm_pattern();
auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
...@@ -232,6 +247,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope ...@@ -232,6 +247,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
skip_layernorm_pattern); skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
skip_layernorm_pattern); skip_layernorm_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(SkipLayerNorm) in op compat failed.";
return;
}
end_pattern_elt_out.push_back(eltwise_add_out); end_pattern_elt_out.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes; std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance}); rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance});
...@@ -349,11 +368,53 @@ static int BuildFusion(Graph* graph, const std::string& name_scope ...@@ -349,11 +368,53 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
return fusion_count; return fusion_count;
} }
} // namespace patterns EmbeddingEltwiseLayerNormFusePass::EmbeddingEltwiseLayerNormFusePass() {
// Constructor body: declares the operator-compatibility rules for this pass.
// NOTE(review): presumably these rules are what the IsCompat(subgraph, graph)
// checks inside BuildFusion evaluate before a matched subgraph is fused --
// confirm against the pass base class.
//
// elementwise_add: inputs "X" and "Y" and output "Out" are declared as
// tensors; the "axis" attribute is limited to the integer values 0 and -1.
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();
// layer_norm: inputs "X", "Scale", "Bias" and outputs "Y", "Mean",
// "Variance" are declared as tensors. "epsilon" is bounded below by 0.0f and
// above by 0.001f (inclusive bounds, going by the GE/LE names -- TODO
// confirm), and "begin_norm_axis" must be greater than 0.
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}
void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
int fusion_count = patterns::BuildFusion(graph, name_scope_); int fusion_count =
EmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
if (fusion_count > 0) { if (fusion_count > 0) {
graph->Set(kEmbEltwiseLayernormPass, new bool(true)); graph->Set(kEmbEltwiseLayernormPass, new bool(true));
} }
......
...@@ -19,8 +19,6 @@ ...@@ -19,8 +19,6 @@
#include <utility> #include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -150,11 +148,13 @@ struct SkipLayerNorm : public PatternBase { ...@@ -150,11 +148,13 @@ struct SkipLayerNorm : public PatternBase {
// Fuse pass derived from FusePassBase. Its ApplyImpl calls BuildFusion and,
// when at least one fusion was performed, sets the kEmbEltwiseLayernormPass
// attribute on the graph.
class EmbeddingEltwiseLayerNormFusePass : public FusePassBase { class EmbeddingEltwiseLayerNormFusePass : public FusePassBase {
public: public:
// Registers the op-compat rules for "elementwise_add" and "layer_norm".
EmbeddingEltwiseLayerNormFusePass();
virtual ~EmbeddingEltwiseLayerNormFusePass() {} virtual ~EmbeddingEltwiseLayerNormFusePass() {}
protected: protected:
// Pass entry point: initialises FusePassBase with name_scope_ and runs
// BuildFusion on the graph.
void ApplyImpl(Graph* graph) const; void ApplyImpl(Graph* graph) const;
// Runs the pattern detectors (scoped under name_scope + "/start", "/second"
// and "/third") on the graph and returns the number of fusions performed.
int BuildFusion(Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const;
// Name scope handed to FusePassBase::Init and to BuildFusion.
const std::string name_scope_{"embedding_eltwise_layernorm_fuse"}; const std::string name_scope_{"embedding_eltwise_layernorm_fuse"};
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册