Unverified commit 95b356a0, authored by W Wilber, committed by GitHub

update embedding_eltwise_layernorm fuse and kernel. test=develop (#23114)

Update the embedding_eltwise_layernorm fuse pass and the fused kernel to support a variable number of embedding inputs.
Parent a31d7328
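The subgraph this pass now targets is an arbitrary number of embedding lookups whose results are summed pairwise and then layer-normalized. A minimal sketch of such a graph, mirroring the Python unit test added later in this commit (the three-input loop and the names used here are illustrative only):

import paddle.fluid as fluid

ids = [fluid.layers.data(name="id%d" % i, shape=[1, 128, 1], dtype="int64",
                         append_batch_size=False) for i in range(3)]
embs = [fluid.layers.embedding(input=x, size=(128, 768), dtype="float32")
        for x in ids]
summed = embs[0]
for emb in embs[1:]:  # chain of elementwise_add ops over all embeddings
    summed = fluid.layers.elementwise_add(summed, emb)
out = fluid.layers.layer_norm(input=summed, begin_norm_axis=2)

The pass rewrites every such chain into a single fused_embedding_eltwise_layernorm op with duplicable Ids and Embs inputs.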
@@ -130,6 +130,7 @@ cc_test(test_skip_layernorm_fuse_pass SRCS skip_layernorm_fuse_pass_tester.cc DE
cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass)
cc_test(test_conv_bn_fuse_pass SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass)
if(WITH_GPU)
  cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass)
  cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
endif()
if(NOT WIN32)
......
@@ -25,180 +25,76 @@ namespace framework {
namespace ir {
namespace patterns {

static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
                               const std::string& arg,
                               bool is_persist = false) {
  PDNode* node =
      pattern->NewNode(name)->assert_is_op_input("lookup_table", arg);
  if (is_persist) return node->assert_is_persistable_var();
  return node;
}

static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
                                   const std::string& arg) {
  PDNode* node = pattern->NewNode(name)
                     ->assert_is_op_output("lookup_table")
                     ->assert_is_op_input("elementwise_add", arg);
  return node;
}

void Embedding2Eltwise1Pattern::operator()() {
  auto* lookup_table1_x =
      create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
  auto* lookup_table2_x =
      create_emb_vars(pattern, lookup_table2_x_repr(), "Ids");
  auto* lookup_table1_w =
      create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
  auto* lookup_table2_w =
      create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
  auto* lookup_table1 =
      pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
  auto* lookup_table2 =
      pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table");
  auto* lookup_table1_out =
      create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
  auto* lookup_table2_out =
      create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y");
  auto* eltwise_add =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                              ->assert_is_op_output("elementwise_add");
  lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
      .LinksTo({lookup_table1_out});
  lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
      .LinksTo({lookup_table2_out});
  eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})
      .LinksTo({eltwise_add_out});
}

void Embedding1Eltwise1Pattern::operator()() {
  auto* lookup_table1_x =
      create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
  auto* lookup_table1_w =
      create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
  auto* lookup_table1 =
      pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
  auto* lookup_table1_out =
      create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
  auto* eltwise_add =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto* eltwise_add_in = pattern->NewNode(eltwise_add_in_repr())
                             ->assert_is_op_input("elementwise_add", "X")
                             ->assert_is_op_output("elementwise_add");
  auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                              ->assert_is_op_output("elementwise_add");
  lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
      .LinksTo({lookup_table1_out});
  eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
      .LinksTo({eltwise_add_out});
}

void SkipLayerNorm::operator()() {
  auto* eltwise_add =
      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
  auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
                              ->assert_is_op_output("elementwise_add")
                              ->assert_is_op_input("layer_norm", "X");
  auto* layer_norm =
      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
  auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
@@ -212,7 +108,6 @@ PDNode* EmbeddingEltwiseLayerNormPattern::operator()() {
                                   ->AsInput()
                                   ->assert_is_persistable_var()
                                   ->assert_is_op_input("layer_norm", "Scale");
  auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
                                  ->AsOutput()
                                  ->assert_is_op_output("layer_norm", "Mean");
@@ -220,33 +115,214 @@ PDNode* EmbeddingEltwiseLayerNormPattern::operator()() {
  auto* layer_norm_variance_var =
      pattern->NewNode(layer_norm_variance_repr())
          ->AsOutput()
          ->assert_is_op_output("layer_norm", "Variance");
  eltwise_add->LinksTo({eltwise_add_out});
  layer_norm
      ->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var})
      .LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var});
}

static int BuildFusion(Graph* graph, const std::string& name_scope
/*const Scope* scope*/) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
std::vector<std::vector<std::pair<Node*, Node*>>> start_pattern_in_nodes;
std::vector<Node*> start_pattern_out_node;
std::vector<std::unordered_set<Node*>> start_pattern_remove_nodes;
// Create pattern.
Embedding2Eltwise1Pattern start_pattern(pattern, name_scope + "/start");
start_pattern();
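// Stage 1: match the start pattern (two lookup_tables feeding one
// elementwise_add) and record its (id, weight) input pairs, its output node
// and the nodes that can be removed.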
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
std::vector<std::pair<Node*, Node*>> ins;
ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w));
ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w));
start_pattern_in_nodes.push_back(ins);
start_pattern_out_node.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
lookup_table2_out, eltwise_add, eltwise_add_out});
start_pattern_remove_nodes.push_back(rm_nodes);
};
gpd(graph, handler);
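// Stage 2: match the inner pattern (one more lookup_table whose result is
// added onto an existing elementwise_add output); it may repeat zero or more
// times per chain.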
std::vector<std::pair<Node*, Node*>> inner_pattern_ins;
std::vector<Node*> inner_pattern_tmp_in;
std::vector<Node*> inner_pattern_out;
std::vector<std::unordered_set<Node*>> inner_pattern_remove_nodes;
GraphPatternDetector gpd2;
auto* pattern2 = gpd2.mutable_pattern();
Embedding1Eltwise1Pattern second_pattern(pattern2, name_scope + "/second");
second_pattern();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
auto in = std::make_pair(lookup_table1_x, lookup_table1_w);
inner_pattern_ins.push_back(in);
inner_pattern_tmp_in.push_back(eltwise_add_in);
inner_pattern_out.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert(
{lookup_table1, lookup_table1_out, eltwise_add, eltwise_add_out});
inner_pattern_remove_nodes.push_back(rm_nodes);
};
gpd2(graph, handler2);
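// Stage 3: match the end pattern (the last elementwise_add output feeding a
// layer_norm) and record the bias, scale, output and layer_norm nodes.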
std::vector<Node*> end_pattern_elt_out;
std::vector<Node*> end_pattern_scales;
std::vector<Node*> end_pattern_biases;
std::vector<Node*> end_pattern_out;
std::vector<Node*> end_patter_layernorms;
std::vector<std::unordered_set<Node*>> end_pattern_remove_nodes;
GraphPatternDetector gpd3;
auto* pattern3 = gpd3.mutable_pattern();
SkipLayerNorm skip_layernorm_pattern(pattern3, name_scope + "/third");
skip_layernorm_pattern();
auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
skip_layernorm_pattern);
end_pattern_elt_out.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance});
end_pattern_remove_nodes.push_back(rm_nodes);
end_pattern_biases.push_back(layer_norm_bias);
end_pattern_scales.push_back(layer_norm_scale);
end_pattern_out.push_back(layer_norm_out);
end_patter_layernorms.push_back(layer_norm);
};
gpd3(graph, handler3);
if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) {
return 0;
}
// Only keep the subgraphs that form fully connected chains.
int fusion_count = 0;
// fusion_id for (i, k, js)
std::vector<std::pair<size_t, std::pair<size_t, std::vector<size_t>>>>
fusion_ids;
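// Follow each start pattern's output through matching inner patterns
// (eltwise_add_in -> eltwise_add_out) until it reaches an end pattern's
// input; only such connected chains become fusion candidates (i, k, js).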
for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) {
Node* tmp = start_pattern_out_node[i];
Node* old_tmp = nullptr;
// get correct inner pattern node order.
std::vector<size_t> js;
while (tmp != old_tmp) {
old_tmp = tmp;
for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) {
if (inner_pattern_tmp_in[j] == tmp) {
tmp = inner_pattern_out[j];
js.push_back(j);
break;
}
}
}
for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) {
if (tmp == end_pattern_elt_out[k]) {
fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js)));
break;
}
}
}
for (size_t num = 0; num < fusion_ids.size(); ++num) {
int i = fusion_ids[num].first;
int k = fusion_ids[num].second.first;
std::vector<size_t> js = fusion_ids[num].second.second;
std::vector<std::string> ids;
std::vector<std::string> embs;
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
}
for (size_t iter = 0; iter < js.size(); ++iter) {
ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));
auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}
return fusion_count;
}

}  // namespace patterns

void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  int fusion_count = patterns::BuildFusion(graph, name_scope_);
  AddStatis(fusion_count);
}
......
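The connected-chain bookkeeping in BuildFusion above can be summarized compactly. A simplified Python sketch of the same index walk over the three groups of matches (illustrative names, not the pass itself):

def connect_chains(start_out, inner_tmp_in, inner_out, end_elt_out):
    # start_out[i]: output node of the i-th start match
    # inner_tmp_in[j] / inner_out[j]: running-sum input / output of the j-th inner match
    # end_elt_out[k]: elementwise_add output consumed by the k-th layer_norm
    fusion_ids = []
    for i, node in enumerate(start_out):
        js, moved = [], True
        while moved:  # follow the elementwise_add chain through inner matches
            moved = False
            for j, tmp_in in enumerate(inner_tmp_in):
                if tmp_in == node:
                    node, moved = inner_out[j], True
                    js.append(j)
                    break
        for k, elt_out in enumerate(end_elt_out):
            if node == elt_out:  # the chain ends right before a layer_norm
                fusion_ids.append((i, k, js))
                break
    return fusion_ids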
@@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
@@ -25,35 +26,76 @@ namespace framework {
namespace ir {
namespace patterns {

// detect start pattern.
//
//   in_var  emb        in_var  emb
//     |      |           |      |
//    lookup_table       lookup_table
//          |                  |
//       lkt_var            lkt_var
//           \                /
//            elementwise_add
//                   |
//              elt_out_var
//
struct Embedding2Eltwise1Pattern : public PatternBase {
  Embedding2Eltwise1Pattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "embedding2_eltwise1") {}

  void operator()();

  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table2_x);
  PATTERN_DECL_NODE(lookup_table1_w);
  PATTERN_DECL_NODE(lookup_table2_w);
  PATTERN_DECL_NODE(lookup_table1);
  PATTERN_DECL_NODE(lookup_table2);
  PATTERN_DECL_NODE(lookup_table1_out);
  PATTERN_DECL_NODE(lookup_table2_out);
  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
};

// detect repeats inner pattern
//
// elt_out_var     in_var  emb
//       \           |      |
//        \         lookup_table
//         \              |
//          \          lkt_var
//           \            /
//           elementwise_add
//                  |
//             elt_out_var
//
struct Embedding1Eltwise1Pattern : public PatternBase {
  Embedding1Eltwise1Pattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "embedding1_eltwise1") {}
  void operator()();
  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table1_w);
  PATTERN_DECL_NODE(lookup_table1);
  PATTERN_DECL_NODE(lookup_table1_out);
  PATTERN_DECL_NODE(eltwise_add_in);
  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
};

// detect end pattern
//
//   elementwise_add
//          |
//     elt_out_var
//  scale   |   bias
//      \   |   /
//      layer_norm
//
struct SkipLayerNorm : public PatternBase {
  SkipLayerNorm(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "skip_layernorm") {}
  void operator()();
  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
  PATTERN_DECL_NODE(layer_norm);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_scale);
@@ -79,6 +121,23 @@ struct EmbeddingEltwiseLayerNormPattern : public PatternBase {
//
// (word, pos, sent, weights_0, weights_1, weights_2,
//  scale, bias)  embedding_eltwise_layernorm  ->  layer_norm_out
//
//
// in_var  emb_var   in_var  emb_var   in_var  emb_var        in_var  emb_var
//   |        |        |        |        |        |             |        |
//  lookup_table      lookup_table      lookup_table    ...    lookup_table
//        |                 |                 |                      |
//     lkt_var           lkt_var           lkt_var                lkt_var
//        \                 /                 |          ...         |
//          elementwise_add                   |                      |
//                  \                        /                       |
//                    elementwise_add                                |
//                           |                                       |
//                        elt_var                                   /
//                           \                                     /
//                               elementwise_add
//                                      |
//                                  layer_norm

class EmbeddingEltwiseLayerNormFusePass : public FusePassBase {
 public:
......
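To make the decomposition above concrete: a chain with four embedding inputs is covered by one start pattern, two inner patterns and one end pattern, and is replaced by a single fused op. A hypothetical mapping (the names here are illustrative, not taken from the code):

# start pattern : lookup(id1, emb1) + lookup(id2, emb2) -> add_12
# inner pattern : add_12  + lookup(id3, emb3)           -> add_123
# inner pattern : add_123 + lookup(id4, emb4)           -> add_1234
# end pattern   : add_1234, scale, bias                 -> layer_norm
fused_op_inputs = {
    "Ids": ["id1", "id2", "id3", "id4"],
    "Embs": ["emb1", "emb2", "emb3", "emb4"],
    "Scale": ["layer_norm_scale"],
    "Bias": ["layer_norm_bias"],
}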
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(SkipLayerNormFusePass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, y) elementwise_add -> elementwise_out
// (elementwise_out, scale, bias) layer_norm -> layer_norm_out...
Layers layers;
auto* x0 = layers.data("x0", {1, 256, 1});
auto* x1 = layers.data("x1", {1, 256, 1});
auto* x2 = layers.data("x2", {1, 256, 1});
auto* x3 = layers.data("x3", {1, 256, 1});
auto* emb0 = layers.data("emb0", {18000, 768}, true);
auto* emb1 = layers.data("emb1", {4, 768}, true);
auto* emb2 = layers.data("emb2", {513, 768}, true);
auto* emb3 = layers.data("emb3", {3, 768}, true);
auto* lkt0 = layers.embedding(x0, emb0);
auto* lkt1 = layers.embedding(x1, emb1);
auto* lkt2 = layers.embedding(x2, emb2);
auto* lkt3 = layers.embedding(x3, emb3);
auto* elementwise_out1 = layers.elementwise_add(lkt0, lkt2);
auto* elementwise_out2 = layers.elementwise_add(elementwise_out1, lkt1);
auto* elementwise_out3 = layers.elementwise_add(elementwise_out2, lkt3);
auto* scale = layers.data("scale", {768}, true);
auto* bias = layers.data("bias", {768}, true);
layers.layer_norm(elementwise_out3, scale, bias);
auto* y0 = layers.data("y0", {1, 256, 1});
auto* y1 = layers.data("y1", {1, 256, 1});
auto* y2 = layers.data("y2", {1, 256, 1});
auto* emb0y = layers.data("emb0y", {18000, 768}, true);
auto* emb1y = layers.data("emb1y", {4, 768}, true);
auto* emb2y = layers.data("emb2y", {513, 768}, true);
auto* lkt0y = layers.embedding(y0, emb0y);
auto* lkt1y = layers.embedding(y1, emb1y);
auto* lkt2y = layers.embedding(y2, emb2y);
auto* elementwise_out1y = layers.elementwise_add(lkt0y, lkt2y);
auto* elementwise_out2y = layers.elementwise_add(elementwise_out1y, lkt1y);
auto* scaley = layers.data("scaley", {768}, true);
auto* biasy = layers.data("biasy", {768}, true);
layers.layer_norm(elementwise_out2y, scaley, biasy);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass =
PassRegistry::Instance().Get("embedding_eltwise_layernorm_fuse_pass");
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_fused_nodes_after =
GetNumOpNodes(graph, "fused_embedding_eltwise_layernorm");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 28,
platform::errors::PreconditionNotMet(
"The number of nodes before and after the fuse does "
"not meet expectations"));
PADDLE_ENFORCE_EQ(
num_fused_nodes_after, 2,
platform::errors::PreconditionNotMet(
"The number of fusion nodes does not meet expectations after fuse"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(embedding_eltwise_layernorm_fuse_pass);
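The expected difference of 28 nodes can be checked by hand, assuming the pass removes exactly the matched lookup_table, elementwise_add and layer_norm ops plus their intermediate outputs and the Mean/Variance variables, and adds one fused op per chain (a sanity check, not part of the test):

def removed_nodes(num_embeddings):
    ops = num_embeddings + (num_embeddings - 1) + 1        # lookup_tables, adds, layer_norm
    variables = num_embeddings + (num_embeddings - 1) + 2  # lkt outs, add outs, Mean, Variance
    return ops + variables - 1                             # one fused op node is added back

assert removed_nodes(4) + removed_nodes(3) == 28  # the two chains built above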
@@ -26,33 +26,19 @@ class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* context) const override {
    PADDLE_ENFORCE_EQ(context->Inputs("Ids").size(),
                      context->Inputs("Embs").size(),
                      platform::errors::InvalidArgument(
                          "The Ids and Embs inputs of "
                          "EmbeddingEltWiseLayerNormOp should have the "
                          "same size."));
    PADDLE_ENFORCE_GE(context->Inputs("Embs").size(), 2UL,
                      platform::errors::InvalidArgument(
                          "Input Embs of EmbeddingEltWiseLayerNormOp should "
                          "have at least 2 tensors."));
    PADDLE_ENFORCE_GE(context->Inputs("Ids").size(), 2UL,
                      platform::errors::InvalidArgument(
                          "Input Ids of EmbeddingEltWiseLayerNormOp should "
                          "have at least 2 tensors."));

    PADDLE_ENFORCE_EQ(
        context->HasInput("Bias"), true,
@@ -70,55 +56,55 @@ class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel {
            "Output(Out) of EmbeddingEltWiseLayerNormOp should not be null."));

    // batch * seq_len * 1
    auto ids_dims = context->GetInputsDim("Ids");
    // word_num * hidden
    auto embs_dims = context->GetInputsDim("Embs");
    // hidden
    auto dims_bias = context->GetInputDim("Bias");
    int batch = ids_dims[0][0];
    int seq_len = ids_dims[0][1];
    int hidden = embs_dims[0][1];
    for (size_t i = 0; i < embs_dims.size(); ++i) {
      PADDLE_ENFORCE_EQ(embs_dims[i].size(), 2,
                        platform::errors::InvalidArgument(
                            "The Emb dim's size should be 2, but found %d.",
                            embs_dims[i].size()));
      PADDLE_ENFORCE_EQ(
          embs_dims[i][1], dims_bias[0],
          platform::errors::InvalidArgument(
              "The second dim (%d) of the Embedding should be equal "
              "to the Bias's size (%d).",
              embs_dims[i][1], dims_bias[0]));
      PADDLE_ENFORCE_EQ(
          embs_dims[i][1], hidden,
          platform::errors::InvalidArgument(
              "The second dim of the Emb (%d) should be equal to hidden (%d).",
              embs_dims[i][1], hidden));
    }

    auto dim_output = framework::make_ddim({batch, seq_len, hidden});
    context->SetOutputDim("Out", dim_output);
    context->ShareLoD("Ids", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto inputs = ctx.MultiInput<framework::Tensor>("Embs");
    auto input_data_type = framework::proto::VarType::Type(0);
    bool flag = 0;
    for (auto* input : inputs) {
      if (input->IsInitialized() && input->numel() > 0) {
        input_data_type = input->type();
        flag = 1;
        break;
      }
    }
    if (flag == 0) {
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "All Inputs of fused_embedding_eltwise_layernorm OP are Empty!"));
    }
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};
@@ -126,15 +112,10 @@ class EmbeddingEltWiseLayerNormOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Ids", "Input id tensors of EmbeddingEltWiseLayerNorm op")
        .AsDuplicable();
    AddInput("Embs", "Input emb tensors of EmbeddingEltWiseLayerNorm op")
        .AsDuplicable();
    AddInput("Bias", "The LayerNorm Bias of EmbeddingEltWiseLayerNorm op");
    AddInput("Scale", "The LayerNorm Scale of EmbeddingEltWiseLayerNorm op");
    AddOutput("Out", "The output of EmbeddingEltWiseLayerNorm op");
@@ -157,10 +138,11 @@ class EmbeddingEltWiseLayerNormOpMaker
EmbeddingEltWiseLayerNorm Operator.

This op is used to optimize the following structure in the ernie model:

id1 -> lookup_table_op -> data1
id2 -> lookup_table_op -> data2
...
idn -> lookup_table_op -> data_n
data1 + data2 + ... + data_n -> Y
Y -> layer_norm -> Out

It is not suggested to use this op in cases other than the ernie structure above.
......
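As a worked shape example consistent with the InferShape logic above: Ids of shape [1, 128, 1] and Embs of shape [vocab, 768] give batch=1, seq_len=128, hidden=768, so Out is [1, 128, 768]. A small Python sketch of the same checks (illustrative only, not the operator code):

def infer_out_shape(ids_dims, embs_dims, bias_dims):
    assert len(ids_dims) == len(embs_dims) and len(embs_dims) >= 2
    hidden = embs_dims[0][1]
    for d in embs_dims:  # every table must be 2-D with a matching hidden size
        assert len(d) == 2 and d[1] == hidden == bias_dims[0]
    batch, seq_len = ids_dims[0][0], ids_dims[0][1]
    return [batch, seq_len, hidden]

print(infer_out_shape([[1, 128, 1]] * 3,
                      [[18000, 768], [513, 768], [4, 768]], [768]))  # [1, 128, 768]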
@@ -16,6 +16,7 @@
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <cub/cub.cuh>  // NOLINT
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
@@ -57,32 +58,28 @@ __device__ inline void LayerNorm(const cv2<T> &thread_data, const int ld,
}

template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids,
                                          const T *scale, const T *bias,
                                          const int64_t *embs, T *output,
                                          float eps, int input_num) {
  cub::Sum pair_sum;
  // blockIdx.x: position in the sequence
  // blockIdx.y: batch
  // gridDim.x: Seq
  // gridDim.y: Batch

  extern __shared__ int64_t array_id[];

  const T rhidden = T(1.f) / T(hidden);
  const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;
  if (threadIdx.x == 0) {
    for (int i = 0; i < input_num; ++i) {
      const int64_t *ids_p = reinterpret_cast<const int64_t *>(ids[i]);
      array_id[i] = ids_p[seq_pos];
    }
  }
  __syncthreads();

  const int64_t out_offset = seq_pos * hidden;
  cv2<T> thread_data;
@@ -91,10 +88,10 @@ __global__ void EmbEltwiseLayernormKernel(
#pragma unroll
  for (int it = threadIdx.x; it < hidden; it += TPB) {
    T val = 0;
    for (int i = 0; i < input_num; ++i) {
      val += reinterpret_cast<const T *>(embs[i])[array_id[i] * hidden + it];
    }
    output[out_offset + it] = val;
    const T rhiddenval = rhidden * val;
@@ -112,47 +109,58 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    using Tensor = framework::Tensor;
    auto &device_ctx = context.template device_context<DeviceContext>();
    auto ids = context.MultiInput<framework::Tensor>("Ids");
    auto embs = context.MultiInput<framework::Tensor>("Embs");
    int input_num = static_cast<int>(ids.size());

    framework::Tensor in_ids_(framework::proto::VarType::INT64),
        in_embs_(framework::proto::VarType::INT64);
    framework::DDim in_dim{input_num};
    int device_id;
    cudaGetDevice(&device_id);
    in_ids_.Resize(in_dim);
    in_embs_.Resize(in_dim);
    int64_t *in_ids_d =
        in_ids_.mutable_data<int64_t>(platform::CUDAPlace(device_id));
    int64_t *in_embs_d =
        in_embs_.mutable_data<int64_t>(platform::CUDAPlace(device_id));

    std::vector<int64_t> in1s, in2s;
    for (int i = 0; i < input_num; ++i) {
      in1s.push_back(reinterpret_cast<uintptr_t>(ids[i]->data<int64_t>()));
      in2s.push_back(reinterpret_cast<uintptr_t>(embs[i]->data<T>()));
    }
    cudaMemcpyAsync(in_ids_d, in1s.data(), sizeof(int64_t) * input_num,
                    cudaMemcpyHostToDevice, device_ctx.stream());
    cudaMemcpyAsync(in_embs_d, in2s.data(), sizeof(int64_t) * input_num,
                    cudaMemcpyHostToDevice, device_ctx.stream());

    auto *bias = context.Input<framework::Tensor>("Bias");
    auto *scale = context.Input<framework::Tensor>("Scale");
    auto *out = context.Output<framework::Tensor>("Out");

    // should be (B * S * hidden)
    auto id0_dims = ids[0]->dims();
    auto emb0_dims = embs[0]->dims();

    int batch = id0_dims[0];
    int seq_len = id0_dims[1];
    int hidden = emb0_dims[1];

    auto *bias_d = bias->data<T>();
    auto *scale_d = scale->data<T>();
    auto *output_d = out->mutable_data<T>(context.GetPlace());

    float eps = context.Attr<float>("epsilon");
    const unsigned tpb = 256;
    const dim3 grid(seq_len, batch, 1);
    const dim3 block(tpb, 1, 1);
    int shared_bytes = input_num * sizeof(int64_t);
    EmbEltwiseLayernormKernel<
        T, tpb><<<grid, block, shared_bytes, device_ctx.stream()>>>(
        hidden, in_ids_d, scale_d, bias_d, in_embs_d, output_d, eps, input_num);
  }
};
......
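For reference, what the kernel computes per (batch, position) is the element-wise sum of the gathered embedding rows followed by layer normalization with the given scale and bias. A NumPy sketch of the same math (a reference model under the stated shapes, not the CUDA implementation):

import numpy as np

def fused_emb_eltwise_layernorm(ids, embs, scale, bias, eps=1e-5):
    # ids: list of [batch, seq_len] int64 index arrays (trailing 1-dim squeezed)
    # embs: list of [vocab_i, hidden] embedding tables with the same hidden size
    summed = sum(emb[idx] for idx, emb in zip(ids, embs))  # [batch, seq_len, hidden]
    mean = summed.mean(axis=-1, keepdims=True)
    var = summed.var(axis=-1, keepdims=True)
    return scale * (summed - mean) / np.sqrt(var + eps) + bias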
@@ -48,6 +48,39 @@ class EmbEltwiseLayerNormFusePassTest(PassTest):
        add2 = fluid.layers.elementwise_add(add1, sent_emb)
        hidden1 = fluid.layers.layer_norm(input=add2, begin_norm_axis=2)

        id1 = fluid.layers.data(
            name="id1",
            shape=[1, 128, 1],
            dtype="int64",
            append_batch_size=False)
        id2 = fluid.layers.data(
            name="id2",
            shape=[1, 128, 1],
            dtype="int64",
            append_batch_size=False)
        id3 = fluid.layers.data(
            name="id3",
            shape=[1, 128, 1],
            dtype="int64",
            append_batch_size=False)
        id4 = fluid.layers.data(
            name="id4",
            shape=[1, 128, 1],
            dtype="int64",
            append_batch_size=False)
        emb1 = fluid.layers.embedding(
            input=id1, size=(128, 768), dtype='float32')
        emb2 = fluid.layers.embedding(
            input=id2, size=(128, 768), dtype='float32')
        emb3 = fluid.layers.embedding(
            input=id3, size=(128, 768), dtype='float32')
        emb4 = fluid.layers.embedding(
            input=id4, size=(128, 768), dtype='float32')
        add_1 = fluid.layers.elementwise_add(emb1, emb2)
        add_2 = fluid.layers.elementwise_add(add_1, emb3)
        add_3 = fluid.layers.elementwise_add(add_2, emb4)
        hidden_1 = fluid.layers.layer_norm(input=add_3, begin_norm_axis=2)

        self.feeds = {
            "word_id": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
@@ -55,11 +88,19 @@ class EmbEltwiseLayerNormFusePassTest(PassTest):
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
            "sent_id": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
            "id1": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
            "id2": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
            "id3": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
            "id4": np.random.randint(
                low=0, high=128, size=(1, 128, 1)).astype("int64"),
        }
        self.fetch_list = [hidden1, hidden_1]
        self.pass_names = "embedding_eltwise_layernorm_fuse_pass"
        self.fused_op_type = "fused_embedding_eltwise_layernorm"
        self.num_fused_ops = 2

    def test_check_output(self):
        use_gpu_set = [True]
......