Commit 3b0eaee9, authored by: W Wangzheee

support preln_ernie

Parent 52bbaae9
...@@ -103,6 +103,8 @@ target_link_libraries(generate_pass pass_desc_proto)
if(WITH_TENSORRT)
  pass_library(trt_map_matmul_to_mul_pass inference)
  pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
  pass_library(preln_skip_layernorm_fuse_pass inference)
endif()
if(WITH_GPU OR WITH_ROCM)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
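// Helper for pattern building: create a variable node that must feed the
// `arg` input slot of a lookup_table / lookup_table_v2 op; when `is_persist`
// is set, the var must additionally be persistable (an embedding weight).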
static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
const std::string& arg,
bool is_persist = false) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
PDNode* node =
pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
if (is_persist) return node->assert_is_persistable_var();
return node;
}
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
const std::string& arg) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
PDNode* node = pattern->NewNode(name)
->assert_is_only_output_of_ops(embedding_ops)
->assert_is_op_input("elementwise_add", arg)
->AsIntermediate();
return node;
}
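// Start pattern: two embedding lookups whose outputs are summed by a single
// elementwise_add (see the PrelnEmbedding2Eltwise1Pattern diagram in the
// header below).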
void PrelnEmbedding2Eltwise1Pattern::operator()() {
auto* lookup_table1_x =
create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
auto* lookup_table2_x =
create_emb_vars(pattern, lookup_table2_x_repr(), "Ids");
auto* lookup_table1_w =
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
auto* lookup_table2_w =
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table2 =
pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
auto* lookup_table1_out =
create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
auto* lookup_table2_out =
create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y");
auto* eltwise_add =
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add");
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})
.LinksTo({eltwise_add_out});
}
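// Inner pattern (may repeat): one more embedding lookup added onto a previous
// sum; eltwise_add_in must itself be the output of an elementwise_add.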
void PrelnEmbedding1Eltwise1Pattern::operator()() {
auto* lookup_table1_x =
create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
auto* lookup_table1_w =
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table1_out =
create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
auto* eltwise_add =
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_in = pattern->NewNode(eltwise_add_in_repr())
->assert_is_op_input("elementwise_add", "X")
->assert_is_op_output("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add");
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
.LinksTo({eltwise_add_out});
}
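// End pattern: an elementwise_add whose output feeds both layer_norm ("X")
// and a following elementwise_add ("Y") -- the pre-layer-norm residual branch.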
void PrelnSkipLayerNorm::operator()() {
auto* eltwise_add =
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("layer_norm", "X")
->assert_is_op_input("elementwise_add", "Y");
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
->assert_is_op_output("layer_norm", "Y")
->AsOutput();
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
eltwise_add->LinksTo({eltwise_add_out});
layer_norm
->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var});
}
} // namespace patterns
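// BuildFusion runs three detectors (start / inner / end patterns) over the
// graph, then stitches every matched chain start -> inner* -> end into one
// fused_preln_embedding_eltwise_layernorm op.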
int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
std::vector<std::vector<std::pair<Node*, Node*>>> start_pattern_in_nodes;
std::vector<Node*> start_pattern_out_node;
std::vector<std::unordered_set<Node*>> start_pattern_remove_nodes;
// Create pattern.
patterns::PrelnEmbedding2Eltwise1Pattern start_pattern(pattern,
name_scope + "/start");
start_pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "Pass(PrelnEmbedding2Eltwise1Pattern) in op compat failed.";
return;
}
std::vector<std::pair<Node*, Node*>> ins;
ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w));
ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w));
start_pattern_in_nodes.push_back(ins);
start_pattern_out_node.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
lookup_table2_out, eltwise_add, eltwise_add_out});
start_pattern_remove_nodes.push_back(rm_nodes);
};
gpd(graph, handler);
std::vector<std::pair<Node*, Node*>> inner_pattern_ins;
std::vector<Node*> inner_pattern_tmp_in;
std::vector<Node*> inner_pattern_out;
std::vector<std::unordered_set<Node*>> inner_pattern_remove_nodes;
GraphPatternDetector gpd2;
auto* pattern2 = gpd2.mutable_pattern();
patterns::PrelnEmbedding1Eltwise1Pattern second_pattern(
pattern2, name_scope + "/second");
second_pattern();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "Pass(PrelnEmbedding1Eltwise1Pattern) in op compat failed.";
return;
}
auto in = std::make_pair(lookup_table1_x, lookup_table1_w);
inner_pattern_ins.push_back(in);
inner_pattern_tmp_in.push_back(eltwise_add_in);
inner_pattern_out.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({lookup_table1, lookup_table1_out, eltwise_add});
inner_pattern_remove_nodes.push_back(rm_nodes);
};
gpd2(graph, handler2);
std::vector<Node*> end_pattern_elt_out;
std::vector<Node*> end_pattern_scales;
std::vector<Node*> end_pattern_biases;
std::vector<Node*> end_pattern_out;
std::vector<Node*> end_pattern_layernorms;
std::vector<Node*> end_pattern_elementwise;
std::vector<std::unordered_set<Node*>> end_pattern_remove_nodes;
GraphPatternDetector gpd3;
auto* pattern3 = gpd3.mutable_pattern();
patterns::PrelnSkipLayerNorm skip_layernorm_pattern(pattern3,
name_scope + "/third");
skip_layernorm_pattern();
auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
skip_layernorm_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed.";
return;
}
end_pattern_elt_out.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance});
end_pattern_remove_nodes.push_back(rm_nodes);
end_pattern_biases.push_back(layer_norm_bias);
end_pattern_scales.push_back(layer_norm_scale);
end_pattern_out.push_back(layer_norm_out);
end_pattern_layernorms.push_back(layer_norm);
end_pattern_elementwise.push_back(eltwise_add);
};
gpd3(graph, handler3);
if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) {
return 0;
}
// Only keep the subgraphs that form a connected chain (start -> inner* -> end).
int fusion_count = 0;
// fusion_id for (i, k, js)
std::vector<std::pair<size_t, std::pair<size_t, std::vector<size_t>>>>
fusion_ids;
for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) {
Node* tmp = start_pattern_out_node[i];
Node* old_tmp = nullptr;
// get correct inner pattern node order.
std::vector<size_t> js;
while (tmp != old_tmp) {
old_tmp = tmp;
for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) {
if (inner_pattern_tmp_in[j] == tmp) {
tmp = inner_pattern_out[j];
js.push_back(j);
break;
}
}
}
for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) {
if (tmp == end_pattern_elt_out[k]) {
fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js)));
break;
}
}
}
for (size_t num = 0; num < fusion_ids.size(); ++num) {
int i = fusion_ids[num].first;
int k = fusion_ids[num].second.first;
std::vector<size_t> js = fusion_ids[num].second.second;
std::vector<std::string> ids;
std::vector<std::string> embs;
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
}
for (size_t iter = 0; iter < js.size(); ++iter) {
ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
OpDesc new_op_desc;
new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("WordId", {ids[0]});
new_op_desc.SetInput("PosId", {ids[1]});
if (ids.size() > 2) {
new_op_desc.SetInput("SentId", {ids[2]});
}
new_op_desc.SetInput("WordEmbedding", {embs[0]});
new_op_desc.SetInput("PosEmbedding", {embs[1]});
if (embs.size() > 2) {
new_op_desc.SetInput("SentEmbedding", {embs[2]});
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out_0", {end_pattern_out[k]->Name()});
new_op_desc.SetOutput("Out_1", {inner_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));
if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold") &&
end_patter_elementwise[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_0_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
new_op_desc.SetAttr(
"out_1_threshold",
end_patter_elementwise[k]->Op()->GetAttr("out_threshold"));
}
auto* preln_embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
preln_embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
preln_embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, end_pattern_out[k]);
IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, end_pattern_elt_out[k]);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}
return fusion_count;
}
PrelnEmbeddingEltwiseLayerNormFusePass::
PrelnEmbeddingEltwiseLayerNormFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}
void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count =
PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
if (fusion_count > 0) {
graph->Set(kPrelnEmbEltwiseLayernormPass, new bool(true));
}
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(preln_embedding_eltwise_layernorm_fuse_pass,
paddle::framework::ir::PrelnEmbeddingEltwiseLayerNormFusePass);
REGISTER_PASS_CAPABILITY(preln_embedding_eltwise_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("lookup_table", 1)
.LE("lookup_table_v2", 1)
.LE("elementweise_add", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// detect start pattern.
//
//   in_var  emb        in_var  emb
//     |      |           |      |
//    lookup_table       lookup_table
//         |                  |
//      lkt_var            lkt_var
//          \                /
//           elementwise_add
//                  |
//             elt_out_var
//
struct PrelnEmbedding2Eltwise1Pattern : public PatternBase {
PrelnEmbedding2Eltwise1Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table2_w);
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table2);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(lookup_table2_out);
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
};
// detect repeated inner pattern
//
//    elt_out_var        in_var   emb
//         \                |      |
//          \             lookup_table
//           \                |
//            \            lkt_var
//             \              /
//            elementwise_add
//              |           |
//   elementwise_add    elt_out_var
//
struct PrelnEmbedding1Eltwise1Pattern : public PatternBase {
PrelnEmbedding1Eltwise1Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(eltwise_add_in);
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
};
// detect end pattern
//
//            elementwise_add
//              |           |
//              |      elt_out_var
//              |    scale  |  bias
//              |       \   |   /
//   elementwise_add    layer_norm
//
struct PrelnSkipLayerNorm : public PatternBase {
PrelnSkipLayerNorm(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "Prelnskip_layernorm") {}
void operator()();
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
// Delete the mean and var nodes in the graph.
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
};
} // namespace patterns
// The PrelnEmbeddingEltwiseLayerNormFusePass detects the following pattern:
//
// inputs                            operator           output
// --------------------------------------------------------------------
// (word, weights_0)                 lookup_table    -> word_emb
// (pos, weights_1)                  lookup_table    -> pos_emb
// (sent, weights_2)                 lookup_table    -> sent_emb
// (word_emb, pos_emb)               elementwise_add -> elementwise_out_0
// (elementwise_out_0, sent_emb)     elementwise_add -> elementwise_out_1
// (elementwise_out_1, scale, bias)  layer_norm      -> layer_norm_out
//
// and then converts the corresponding subgraph to:
//
// (word, pos, sent, weights_0, weights_1, weights_2,
//  scale, bias) fused_preln_embedding_eltwise_layernorm -> layer_norm_out +
//  elementwise_add_out
//
//
//  in_var  emb_var   in_var  emb_var   in_var  emb_var      in_var  emb_var
//    |       |         |       |         |       |            |       |
//   lookup_table      lookup_table      lookup_table   ...   lookup_table
//        |                 |                 |                    |
//     lkt_var           lkt_var           lkt_var              lkt_var
//        \                 /                 |        ...         |
//         elementwise_add                    |                    |
//                  \                        /                     |
//                   elementwise_add                               |
//                          |                                      |
//                       elt_var                                  /
//                          \                                    /
//                           elementwise_add
//                             |           |
//                  elementwise_add     layer_norm
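//
// In short, a sketch of the intended fused-op semantics (not the kernel):
//   emb_sum = word_emb + pos_emb (+ sent_emb)
//   Out_0   = layer_norm(emb_sum, scale, bias)
//   Out_1   = emb_sum   // kept alive for the pre-LN residual branch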
class PrelnEmbeddingEltwiseLayerNormFusePass : public FusePassBase {
public:
PrelnEmbeddingEltwiseLayerNormFusePass();
virtual ~PrelnEmbeddingEltwiseLayerNormFusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
int BuildFusion(Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const;
const std::string name_scope_{"preln_embedding_eltwise_layernorm_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct PrelnSkipLayerNorm : public PatternBase {
PrelnSkipLayerNorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_skip_layernorm") {}
void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(fused_skip_layernorm);
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
PATTERN_DECL_NODE(
elementwise_out); // (elementwise_input_x,elementwise_input_y) ->
// elementwise_out
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
};
void PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
// Create nodes for elementwise add op.
x->assert_is_op_input("elementwise_add", "X");
y->assert_is_op_input("elementwise_add", "Y");
auto *elementwise =
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
                                ->assert_is_op_output("elementwise_add")
                                ->assert_is_op_input("layer_norm", "X")
                                ->assert_is_op_input("elementwise_add", "Y");
// Add links for elementwise_add op.
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
// Create nodes for layer_norm op.
auto *layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Y");
auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto *layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
// Add links for layer_norm op.
layer_norm
->LinksFrom(
{elementwise_out_var, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
}
} // namespace patterns
void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_skip_layernorm_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("preln_skip_layernorm_fuse/x")
->AsInput()
->assert_is_op_input("elementwise_add", "X")
->assert_var_not_persistable();
auto *y = gpd.mutable_pattern()
->NewNode("preln_skip_layernorm_fuse/y")
->AsInput()
->assert_is_op_input("elementwise_add", "Y")
->assert_var_not_persistable();
patterns::PrelnSkipLayerNorm fused_pattern(gpd.mutable_pattern(),
"preln_skip_layernorm_fuse");
fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "preln_skip_layernorm pass in op compat failed.";
return;
}
VLOG(4) << "handle PrelnSkipLayerNorm fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
fused_pattern);
std::unordered_set<const Node *> del_node_set;
// Create a preln_skip_layernorm op node.
OpDesc new_desc;
new_desc.SetType("preln_skip_layernorm");
// inputs
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_desc.SetInput("Bias", {layer_norm_bias->Name()});
if (elementwise->Op()->HasAttr("out_threshold") &&
layer_norm->Op()->HasAttr("out_threshold")) {
new_desc.SetAttr("enable_int8", true);
new_desc.SetAttr("out_0_threshold",
layer_norm->Op()->GetAttr("out_threshold"));
new_desc.SetAttr("out_1_threshold",
elementwise->Op()->GetAttr("out_threshold"));
}
// outputs
new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
new_desc.SetOutput("Out_1", {elementwise_out->Name()});
// attrs
new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
new_desc.SetAttr("begin_norm_axis",
layer_norm->Op()->GetAttr("begin_norm_axis"));
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(elementwise);
del_node_set.insert(layer_norm);
del_node_set.insert(layer_norm_mean);
del_node_set.insert(layer_norm_variance);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(layer_norm_scale, fused_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, layer_norm_out);
IR_NODE_LINK_TO(fused_node, elementwise_out);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(preln_skip_layernorm_fuse_pass,
paddle::framework::ir::PrelnSkipLayerNormFusePass);
REGISTER_PASS_CAPABILITY(preln_skip_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("layer_norm", 0));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
//          |           |                           |          |
//     other_op1    other_op2                  other_op1   other_op2
//          |           |          fuse             \         /
//          |------elementwise_add  ->        preln_skip_layernorm
//          |           |                           |         |
//     other_op4    layer_norm                 other_op4  other_op3
//                      |
//                 other_op3
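//
// Sketch of the fused op's semantics (for orientation only, not the kernel):
//   sum   = X + Y
//   Out_0 = layer_norm(sum, Scale, Bias)   // normalized branch
//   Out_1 = sum                            // pre-LN residual branch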
class Graph;
class PrelnSkipLayerNormFusePass : public FusePassBase {
public:
PrelnSkipLayerNormFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}
virtual ~PrelnSkipLayerNormFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
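For orientation, here is a small standalone C++ reference of the numerics the fused preln_skip_layernorm op is expected to produce: Out_0 = layer_norm(X + Y) and Out_1 = X + Y. This is illustrative only, not Paddle code; names like PrelnSkipLayerNormRef are made up for this sketch.

// Standalone reference for the preln_skip_layernorm numerics (illustrative
// only). Computes Out_0 = layer_norm(x + y) and Out_1 = x + y for one row.
#include <cmath>
#include <cstdio>
#include <vector>

struct PrelnOut {
  std::vector<float> out0;  // layer_norm(x + y), the normalized branch
  std::vector<float> out1;  // x + y, the pre-LN residual branch
};

PrelnOut PrelnSkipLayerNormRef(const std::vector<float> &x,
                               const std::vector<float> &y,
                               const std::vector<float> &scale,
                               const std::vector<float> &bias, float epsilon) {
  const size_t n = x.size();
  PrelnOut r{std::vector<float>(n), std::vector<float>(n)};
  float mean = 0.f, var = 0.f;
  for (size_t i = 0; i < n; ++i) {
    r.out1[i] = x[i] + y[i];  // Out_1 is the raw sum
    mean += r.out1[i];
  }
  mean /= n;
  for (size_t i = 0; i < n; ++i) {
    float d = r.out1[i] - mean;
    var += d * d;
  }
  var /= n;
  const float inv_std = 1.f / std::sqrt(var + epsilon);
  for (size_t i = 0; i < n; ++i)
    r.out0[i] = (r.out1[i] - mean) * inv_std * scale[i] + bias[i];
  return r;
}

int main() {
  auto r = PrelnSkipLayerNormRef({1, 2, 3, 4}, {0.5f, 0.5f, 0.5f, 0.5f},
                                 {1, 1, 1, 1}, {0, 0, 0, 0}, 1e-5f);
  std::printf("out0[0]=%f out1[0]=%f\n", r.out0[0], r.out1[0]);
  return 0;
}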
...@@ -82,22 +82,24 @@ const std::vector<std::string> kTRTSubgraphPasses({
      "quant_conv2d_dequant_fuse_pass",              //
      "delete_quant_dequant_op_pass",                //
      "delete_quant_dequant_filter_op_pass",         //
      // "fc_fuse_pass",                             //
      "simplify_with_basic_ops_pass",                //
      "embedding_eltwise_layernorm_fuse_pass",       //
      "preln_embedding_eltwise_layernorm_fuse_pass", //
      "multihead_matmul_fuse_pass_v2",               //
      "multihead_matmul_fuse_pass_v3",               //
      "skip_layernorm_fuse_pass",                    //
      "preln_skip_layernorm_fuse_pass",              //
      "conv_bn_fuse_pass",                           //
      "unsqueeze2_eltwise_fuse_pass",                //
      "trt_squeeze2_matmul_fuse_pass",               //
      "trt_reshape2_matmul_fuse_pass",               //
      "trt_flatten2_matmul_fuse_pass",               //
      "trt_map_matmul_v2_to_mul_pass",               //
      "trt_map_matmul_v2_to_matmul_pass",            //
      "trt_map_matmul_to_mul_pass",                  //
      "fc_fuse_pass",                                //
      "conv_elementwise_add_fuse_pass",              //
      "add_support_int8_pass",
      "tensorrt_subgraph_pass",                      //
      "conv_bn_fuse_pass",                           //
......
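The two new passes only take effect inside the TensorRT subgraph flow shown above. Assuming a GPU build of Paddle Inference with TensorRT, the following C++ sketch shows one way to exercise this pass list end to end on an ERNIE-style model; the model paths and TensorRT parameters are placeholders, not values from this commit.

// Hedged usage sketch: run a model through Paddle Inference with TensorRT
// enabled so that kTRTSubgraphPasses (including the two new preln fuse
// passes) are applied. Paths and sizes below are placeholders.
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("ernie/model.pdmodel", "ernie/model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/512, /*device_id=*/0);
  // Workspace size, max batch, min subgraph size, precision, static engine,
  // calibration -- tune these for the actual deployment.
  config.EnableTensorRtEngine(1 << 30, /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle_infer::PrecisionType::kHalf,
                              /*use_static=*/false, /*use_calib_mode=*/false);
  auto predictor = paddle_infer::CreatePredictor(config);
  // ... feed the Ids input tensors and call predictor->Run() as usual.
  return 0;
}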