Unverified commit ae74c404, authored by: W Wangzheee, committed by: GitHub

[pass_enhance] embedding_eltwise_layernorm_fuse_pass (#33973)

Parent f2068eec
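This commit turns `BuildFusion` into a member of `EmbeddingEltwiseLayerNormFusePass` and adds operator-compatibility declarations, so every matched subgraph is validated with `IsCompat` before it is rewritten. The following is a minimal sketch of that gating idiom in isolation, assuming Paddle's `FusePassBase`/`OpCompat` API; `MyFusePass` and the elided pattern-building steps are hypothetical, and only the `AddOpCompat()`/`IsCompat()` usage mirrors the diff below.

// Minimal sketch of the op-compat gating idiom, assuming Paddle's
// FusePassBase API. "MyFusePass" is hypothetical; only the
// AddOpCompat()/IsCompat() calls mirror the change in this commit.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class MyFusePass : public FusePassBase {
 public:
  MyFusePass() {
    // Declare what an acceptable elementwise_add looks like; matched ops
    // that violate these constraints must not be fused.
    AddOpCompat(OpCompat("elementwise_add"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Y")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddAttr("axis")
        .IsIntIn({0, -1})
        .End();
  }

 protected:
  void ApplyImpl(Graph* graph) const override {
    GraphPatternDetector gpd;
    // ... build a pattern on gpd.mutable_pattern() ...
    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
      // Gate every match: skip the rewrite when the matched ops do not
      // satisfy the declared constraints, instead of fusing blindly.
      if (!IsCompat(subgraph, g)) {
        LOG(WARNING) << "op compat check failed; skipping fusion.";
        return;
      }
      // ... rewrite the matched subgraph here ...
    };
    gpd(graph, handler);
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle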
@@ -136,8 +136,12 @@ void SkipLayerNorm::operator()() {
       ->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var})
       .LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var});
 }
-static int BuildFusion(Graph* graph, const std::string& name_scope
-                       /*const Scope* scope*/) {
+}  // namespace patterns
+
+int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
+    Graph* graph, const std::string& name_scope
+    /*const Scope* scope*/) const {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
@@ -146,7 +150,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
   std::vector<std::unordered_set<Node*>> start_pattern_remove_nodes;
   // Create pattern.
-  Embedding2Eltwise1Pattern start_pattern(pattern, name_scope + "/start");
+  patterns::Embedding2Eltwise1Pattern start_pattern(pattern,
+                                                    name_scope + "/start");
   start_pattern();
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -162,6 +167,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
                               start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
+    if (!IsCompat(subgraph, graph)) {
+      LOG(WARNING) << "Pass(Embedding2Eltwise1Pattern) in op compat failed.";
+      return;
+    }
     std::vector<std::pair<Node*, Node*>> ins;
     ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w));
     ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w));
@@ -182,7 +191,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
   GraphPatternDetector gpd2;
   auto* pattern2 = gpd2.mutable_pattern();
-  Embedding1Eltwise1Pattern second_pattern(pattern2, name_scope + "/second");
+  patterns::Embedding1Eltwise1Pattern second_pattern(pattern2,
+                                                     name_scope + "/second");
   second_pattern();
   auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
@@ -194,6 +204,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
+    if (!IsCompat(subgraph, graph)) {
+      LOG(WARNING) << "Pass(Embedding1Eltwise1Pattern) in op compat failed.";
+      return;
+    }
     auto in = std::make_pair(lookup_table1_x, lookup_table1_w);
     inner_pattern_ins.push_back(in);
     inner_pattern_tmp_in.push_back(eltwise_add_in);
@@ -214,7 +228,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
   std::vector<std::unordered_set<Node*>> end_pattern_remove_nodes;
   GraphPatternDetector gpd3;
   auto* pattern3 = gpd3.mutable_pattern();
-  SkipLayerNorm skip_layernorm_pattern(pattern3, name_scope + "/third");
+  patterns::SkipLayerNorm skip_layernorm_pattern(pattern3,
+                                                 name_scope + "/third");
   skip_layernorm_pattern();
   auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
@@ -232,6 +247,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
                               skip_layernorm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                               skip_layernorm_pattern);
+    if (!IsCompat(subgraph, graph)) {
+      LOG(WARNING) << "Pass(SkipLayerNorm) in op compat failed.";
+      return;
+    }
     end_pattern_elt_out.push_back(eltwise_add_out);
     std::unordered_set<Node*> rm_nodes;
     rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance});
@@ -349,11 +368,53 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
   return fusion_count;
 }
-}  // namespace patterns
+EmbeddingEltwiseLayerNormFusePass::EmbeddingEltwiseLayerNormFusePass() {
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({0, -1})
+      .End();
+
+  AddOpCompat(OpCompat("layer_norm"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Scale")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Mean")
+      .IsTensor()
+      .End()
+      .AddOutput("Variance")
+      .IsTensor()
+      .End()
+      .AddAttr("epsilon")
+      .IsNumGE(0.0f)
+      .IsNumLE(0.001f)
+      .End()
+      .AddAttr("begin_norm_axis")
+      .IsNumGT(0)
+      .End();
+}
+
 void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);
-  int fusion_count = patterns::BuildFusion(graph, name_scope_);
+  int fusion_count =
+      EmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
   if (fusion_count > 0) {
     graph->Set(kEmbEltwiseLayernormPass, new bool(true));
   }
......
@@ -19,8 +19,6 @@
 #include <utility>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 namespace paddle {
 namespace framework {
@@ -150,11 +148,13 @@ struct SkipLayerNorm : public PatternBase {
 class EmbeddingEltwiseLayerNormFusePass : public FusePassBase {
  public:
+  EmbeddingEltwiseLayerNormFusePass();
   virtual ~EmbeddingEltwiseLayerNormFusePass() {}
 
  protected:
   void ApplyImpl(Graph* graph) const;
+  int BuildFusion(Graph* graph, const std::string& name_scope
+                  /*const Scope* scope*/) const;
   const std::string name_scope_{"embedding_eltwise_layernorm_fuse"};
 };
......
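For reference, a fuse pass like this is typically exercised by fetching it from the pass registry and applying it to a graph. Below is a hedged sketch of the common unit-test idiom, not part of this commit; the registry key "embedding_eltwise_layernorm_fuse_pass" is assumed to match the pass's REGISTER_PASS entry.

// Hedged usage sketch (not part of this commit): applying the pass to a
// Graph via Paddle's pass registry, as the framework's pass tests do.
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"

void RunPass(std::unique_ptr<paddle::framework::ir::Graph>* graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "embedding_eltwise_layernorm_fuse_pass");
  // Apply() rewrites the graph in place and returns it.
  graph->reset(pass->Apply(graph->release()));
}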