diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 8b5f5aa888ab04fd2dfb7722c55b44ca26211a4d..f90715965ea69f5947c95a37b05d1b638457cec8 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -130,6 +130,7 @@ cc_test(test_skip_layernorm_fuse_pass SRCS skip_layernorm_fuse_pass_tester.cc DE cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass) cc_test(test_conv_bn_fuse_pass SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass) if(WITH_GPU) + cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass) cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass) endif() if(NOT WIN32) diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc index d02688dfd13a91d04f317a4db58af7b3f22450bd..215d43a2f06afb57d625b0f891c45a77296deb3b 100644 --- a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc @@ -25,180 +25,76 @@ namespace framework { namespace ir { namespace patterns { -static int BuildFusion(Graph* graph, const std::string& name_scope, - const Scope* scope) { - GraphPatternDetector gpd; - auto* pattern = gpd.mutable_pattern(); - - // Create pattern. - EmbeddingEltwiseLayerNormPattern emb_eltwise_layernorm_pattern(pattern, - name_scope); - emb_eltwise_layernorm_pattern(); - - int fusion_count{0}; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table3_x, lookup_table3_x, - emb_eltwise_layernorm_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table3_w, lookup_table3_w, - emb_eltwise_layernorm_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table3, lookup_table3, - emb_eltwise_layernorm_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(lookup_table3_out, lookup_table3_out, - emb_eltwise_layernorm_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_12, eltwise_add_12, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_12_out, eltwise_add_12_out, - emb_eltwise_layernorm_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, - emb_eltwise_layernorm_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, - 
emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, - emb_eltwise_layernorm_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance, - emb_eltwise_layernorm_pattern); - - auto get_persist_tensor_dims = [&](std::string name) -> framework::DDim { - auto* var = scope->FindVar(name); - PADDLE_ENFORCE_NOT_NULL(var, - platform::errors::PreconditionNotMet( - "Cant not found the %d var in scope.", name)); - return var->GetMutable()->dims(); - }; - - // Check the weight dims. - auto word_emb_dims = get_persist_tensor_dims(lookup_table1_w->Name()); - auto pos_emb_dims = get_persist_tensor_dims(lookup_table2_w->Name()); - auto sent_emb_dims = get_persist_tensor_dims(lookup_table3_w->Name()); - if (word_emb_dims.size() != 2 || pos_emb_dims.size() != 2 || - sent_emb_dims.size() != 2 || word_emb_dims[1] != pos_emb_dims[1] || - word_emb_dims[1] != sent_emb_dims[1]) { - return; - } - - OpDesc new_op_desc; - new_op_desc.SetType("fused_embedding_eltwise_layernorm"); - new_op_desc.SetInput("WordId", {lookup_table1_x->Name()}); - new_op_desc.SetInput("PosId", {lookup_table2_x->Name()}); - new_op_desc.SetInput("SentId", {lookup_table3_x->Name()}); - - new_op_desc.SetInput("WordEmb", {lookup_table1_w->Name()}); - new_op_desc.SetInput("PosEmb", {lookup_table2_w->Name()}); - new_op_desc.SetInput("SentEmb", {lookup_table3_w->Name()}); - - new_op_desc.SetInput("Bias", {layer_norm_bias->Name()}); - new_op_desc.SetInput("Scale", {layer_norm_scale->Name()}); - new_op_desc.SetOutput("Out", {layer_norm_out->Name()}); - new_op_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon")); - - auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc); - IR_NODE_LINK_TO(lookup_table1_x, embedding_eltwise_layernorm); - IR_NODE_LINK_TO(lookup_table2_x, embedding_eltwise_layernorm); - IR_NODE_LINK_TO(lookup_table3_x, embedding_eltwise_layernorm); - - IR_NODE_LINK_TO(lookup_table1_w, embedding_eltwise_layernorm); - IR_NODE_LINK_TO(lookup_table2_w, embedding_eltwise_layernorm); - IR_NODE_LINK_TO(lookup_table3_w, embedding_eltwise_layernorm); - IR_NODE_LINK_TO(layer_norm_bias, embedding_eltwise_layernorm); - IR_NODE_LINK_TO(layer_norm_scale, embedding_eltwise_layernorm); - IR_NODE_LINK_TO(embedding_eltwise_layernorm, layer_norm_out); - - std::unordered_set marked_nodes( - {lookup_table1, lookup_table2, lookup_table3, lookup_table1_out, - lookup_table2_out, lookup_table3_out, eltwise_add_12, - eltwise_add_12_out, eltwise_add, eltwise_add_out, layer_norm, - layer_norm_mean, layer_norm_variance}); - // Remove unneeded nodes. - GraphSafeRemoveNodes(graph, marked_nodes); - ++fusion_count; - }; - gpd(graph, handler); - - return fusion_count; +static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name, + const std::string& arg, + bool is_persist = false) { + PDNode* node = + pattern->NewNode(name)->assert_is_op_input("lookup_table", arg); + if (is_persist) return node->assert_is_persistable_var(); + return node; } - -PDNode* EmbeddingEltwiseLayerNormPattern::operator()() { - // Create shared nodes. 
- auto create_emb_vars = [&](const std::string& name, const std::string& arg, - bool is_persist = false) -> PDNode* { - PDNode* node = pattern->NewNode(name) - ->assert_is_op_input("lookup_table", arg) - ->AsInput(); - if (is_persist) return node->assert_is_persistable_var(); - return node; - }; - - auto create_emb_out_vars = [&](const std::string& name, - const std::string& arg) -> PDNode* { - PDNode* node = pattern->NewNode(name) - ->AsIntermediate() - ->assert_is_op_output("lookup_table") - ->assert_is_op_input("elementwise_add", arg); - return node; - }; - - auto* lookup_table1_x = create_emb_vars(lookup_table1_x_repr(), "Ids"); - auto* lookup_table2_x = create_emb_vars(lookup_table2_x_repr(), "Ids"); - auto* lookup_table3_x = create_emb_vars(lookup_table3_x_repr(), "Ids"); - auto* lookup_table1_w = create_emb_vars(lookup_table1_w_repr(), "W", true); - auto* lookup_table2_w = create_emb_vars(lookup_table2_w_repr(), "W", true); - auto* lookup_table3_w = create_emb_vars(lookup_table3_w_repr(), "W", true); - +static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name, + const std::string& arg) { + PDNode* node = pattern->NewNode(name) + ->assert_is_op_output("lookup_table") + ->assert_is_op_input("elementwise_add", arg); + return node; +} +void Embedding2Eltwise1Pattern::operator()() { + auto* lookup_table1_x = + create_emb_vars(pattern, lookup_table1_x_repr(), "Ids"); + auto* lookup_table2_x = + create_emb_vars(pattern, lookup_table2_x_repr(), "Ids"); + auto* lookup_table1_w = + create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); + auto* lookup_table2_w = + create_emb_vars(pattern, lookup_table2_w_repr(), "W", true); auto* lookup_table1 = pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table"); auto* lookup_table2 = pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table"); - auto* lookup_table3 = - pattern->NewNode(lookup_table3_repr())->assert_is_op("lookup_table"); - - auto* lookup_table1_out = create_emb_out_vars(lookup_table1_out_repr(), "X"); - auto* lookup_table2_out = create_emb_out_vars(lookup_table2_out_repr(), "Y"); - auto* lookup_table3_out = create_emb_out_vars(lookup_table3_out_repr(), "Y"); - - auto* eltwise_add_12 = - pattern->NewNode(eltwise_add_12_repr())->assert_is_op("elementwise_add"); - auto* eltwise_add_12_out = pattern->NewNode(eltwise_add_12_out_repr()) - ->AsIntermediate() - ->assert_is_op_output("elementwise_add") - ->assert_is_op_input("elementwise_add", "X"); - + auto* lookup_table1_out = + create_emb_out_vars(pattern, lookup_table1_out_repr(), "X"); + auto* lookup_table2_out = + create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y"); auto* eltwise_add = pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) - ->AsIntermediate() ->assert_is_op_output("elementwise_add"); - + lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) + .LinksTo({lookup_table1_out}); + lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w}) + .LinksTo({lookup_table2_out}); + eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out}) + .LinksTo({eltwise_add_out}); +} +void Embedding1Eltwise1Pattern::operator()() { + auto* lookup_table1_x = + create_emb_vars(pattern, lookup_table1_x_repr(), "Ids"); + auto* lookup_table1_w = + create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); + auto* lookup_table1 = + pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table"); + auto* lookup_table1_out = + 
      create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
+  auto* eltwise_add =
+      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
+  auto* eltwise_add_in = pattern->NewNode(eltwise_add_in_repr())
+                             ->assert_is_op_input("elementwise_add", "X")
+                             ->assert_is_op_output("elementwise_add");
+  auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
+                              ->assert_is_op_output("elementwise_add");
+  lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
+      .LinksTo({lookup_table1_out});
+  eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
+      .LinksTo({eltwise_add_out});
+}
+void SkipLayerNorm::operator()() {
+  auto* eltwise_add =
+      pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
+  auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
+                              ->assert_is_op_output("elementwise_add")
+                              ->assert_is_op_input("layer_norm", "X");
   auto* layer_norm =
       pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
   auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
@@ -212,7 +108,6 @@ PDNode* EmbeddingEltwiseLayerNormPattern::operator()() {
                                  ->AsInput()
                                  ->assert_is_persistable_var()
                                  ->assert_is_op_input("layer_norm", "Scale");
-
   auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
                                   ->AsOutput()
                                   ->assert_is_op_output("layer_norm", "Mean");
@@ -220,33 +115,214 @@ PDNode* EmbeddingEltwiseLayerNormPattern::operator()() {
       pattern->NewNode(layer_norm_variance_repr())
           ->AsOutput()
           ->assert_is_op_output("layer_norm", "Variance");
-
-  // Link all nodes together
-  lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
-      .LinksTo({lookup_table1_out});
-  lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
-      .LinksTo({lookup_table2_out});
-  lookup_table3->LinksFrom({lookup_table3_x, lookup_table3_w})
-      .LinksTo({lookup_table3_out});
-  eltwise_add_12->LinksFrom({lookup_table1_out, lookup_table2_out})
-      .LinksTo({eltwise_add_12_out});
-  eltwise_add->LinksFrom({lookup_table3_out, eltwise_add_12_out})
-      .LinksTo({eltwise_add_out});
+  eltwise_add->LinksTo({eltwise_add_out});
   layer_norm
       ->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var})
       .LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var});
-  return layer_norm_out;
+}
+static int BuildFusion(Graph* graph, const std::string& name_scope
+                       /*const Scope* scope*/) {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  std::vector<std::vector<std::pair<Node*, Node*>>> start_pattern_in_nodes;
+  std::vector<Node*> start_pattern_out_node;
+  std::vector<std::unordered_set<const Node*>> start_pattern_remove_nodes;
+
+  // Create pattern.
+  Embedding2Eltwise1Pattern start_pattern(pattern, name_scope + "/start");
+  start_pattern();
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
+                              start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out,
+                              start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
+    std::vector<std::pair<Node*, Node*>> ins;
+    ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w));
+    ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w));
+    start_pattern_in_nodes.push_back(ins);
+    start_pattern_out_node.push_back(eltwise_add_out);
+
+    std::unordered_set<const Node*> rm_nodes;
+    rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
+                     lookup_table2_out, eltwise_add, eltwise_add_out});
+    start_pattern_remove_nodes.push_back(rm_nodes);
+  };
+  gpd(graph, handler);
+
+  std::vector<std::pair<Node*, Node*>> inner_pattern_ins;
+  std::vector<Node*> inner_pattern_tmp_in;
+  std::vector<Node*> inner_pattern_out;
+  std::vector<std::unordered_set<const Node*>> inner_pattern_remove_nodes;
+
+  GraphPatternDetector gpd2;
+  auto* pattern2 = gpd2.mutable_pattern();
+  Embedding1Eltwise1Pattern second_pattern(pattern2, name_scope + "/second");
+  second_pattern();
+  auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
+                              second_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
+    auto in = std::make_pair(lookup_table1_x, lookup_table1_w);
+    inner_pattern_ins.push_back(in);
+    inner_pattern_tmp_in.push_back(eltwise_add_in);
+    inner_pattern_out.push_back(eltwise_add_out);
+
+    std::unordered_set<const Node*> rm_nodes;
+    rm_nodes.insert(
+        {lookup_table1, lookup_table1_out, eltwise_add, eltwise_add_out});
+    inner_pattern_remove_nodes.push_back(rm_nodes);
+  };
+  gpd2(graph, handler2);
+
+  std::vector<Node*> end_pattern_elt_out;
+  std::vector<Node*> end_pattern_scales;
+  std::vector<Node*> end_pattern_biases;
+  std::vector<Node*> end_pattern_out;
+  std::vector<Node*> end_patter_layernorms;
+  std::vector<std::unordered_set<const Node*>> end_pattern_remove_nodes;
+  GraphPatternDetector gpd3;
+  auto* pattern3 = gpd3.mutable_pattern();
+  SkipLayerNorm skip_layernorm_pattern(pattern3, name_scope + "/third");
+  skip_layernorm_pattern();
+  auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out,
+                              skip_layernorm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
+                              skip_layernorm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias,
+                              skip_layernorm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
+                              skip_layernorm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean,
+                              skip_layernorm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
+                              skip_layernorm_pattern);
+    end_pattern_elt_out.push_back(eltwise_add_out);
+    std::unordered_set<const Node*> rm_nodes;
+    rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance});
+    end_pattern_remove_nodes.push_back(rm_nodes);
+    end_pattern_biases.push_back(layer_norm_bias);
+    end_pattern_scales.push_back(layer_norm_scale);
+    end_pattern_out.push_back(layer_norm_out);
+    end_patter_layernorms.push_back(layer_norm);
+  };
+  gpd3(graph, handler3);
+
+  if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) {
+    return 0;
+  }
+  // Only keep the subgraphs that lie in connected domains.
+  int fusion_count = 0;
+  // fusion_id for (i, k, js)
+  std::vector<std::pair<size_t, std::pair<size_t, std::vector<size_t>>>>
+      fusion_ids;
+  for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) {
+    Node* tmp = start_pattern_out_node[i];
+    Node* old_tmp = nullptr;
+    // get correct inner pattern node order.
+    std::vector<size_t> js;
+    while (tmp != old_tmp) {
+      old_tmp = tmp;
+      for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) {
+        if (inner_pattern_tmp_in[j] == tmp) {
+          tmp = inner_pattern_out[j];
+          js.push_back(j);
+          break;
+        }
+      }
+    }
+
+    for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) {
+      if (tmp == end_pattern_elt_out[k]) {
+        fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js)));
+        break;
+      }
+    }
+  }
+
+  for (size_t num = 0; num < fusion_ids.size(); ++num) {
+    int i = fusion_ids[num].first;
+    int k = fusion_ids[num].second.first;
+    std::vector<size_t> js = fusion_ids[num].second.second;
+
+    std::vector<std::string> ids;
+    std::vector<std::string> embs;
+    for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
+      ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
+      embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
+    }
+    for (size_t iter = 0; iter < js.size(); ++iter) {
+      ids.push_back(inner_pattern_ins[js[iter]].first->Name());
+      embs.push_back(inner_pattern_ins[js[iter]].second->Name());
+    }
+    OpDesc new_op_desc;
+    new_op_desc.SetType("fused_embedding_eltwise_layernorm");
+    new_op_desc.SetInput("Ids", ids);
+    new_op_desc.SetInput("Embs", embs);
+    new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
+    new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
+    new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
+    new_op_desc.SetAttr("epsilon",
+                        end_patter_layernorms[k]->Op()->GetAttr("epsilon"));
+    auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
+
+    for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
+      IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
+                      embedding_eltwise_layernorm);
+      IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
+                      embedding_eltwise_layernorm);
+    }
+    for (size_t iter = 0; iter < js.size(); ++iter) {
+      IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
+                      embedding_eltwise_layernorm);
+      IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
+                      embedding_eltwise_layernorm);
+    }
+    IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
+    IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
+    IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);
+
+    // Remove unneeded nodes.
+ std::unordered_set marked_nodes; + marked_nodes.insert(start_pattern_remove_nodes[i].begin(), + start_pattern_remove_nodes[i].end()); + marked_nodes.insert(end_pattern_remove_nodes[k].begin(), + end_pattern_remove_nodes[k].end()); + for (size_t iter = 0; iter < js.size(); ++iter) { + marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(), + inner_pattern_remove_nodes[js[iter]].end()); + } + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + } + + return fusion_count; } } // namespace patterns void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { FusePassBase::Init(name_scope_, graph); - auto* scope = param_scope(); - PADDLE_ENFORCE_NOT_NULL( - scope, platform::errors::PreconditionNotMet( - "The scope is null, please initialize the scope first.")); - int fusion_count = patterns::BuildFusion(graph, name_scope_, scope); + int fusion_count = patterns::BuildFusion(graph, name_scope_); AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h index 8311badb622c75b7d02d8940d184136929300f29..644eb1cf89221c4e6e22e3d767b4b802702d7b88 100644 --- a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h +++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h @@ -16,6 +16,7 @@ #include #include +#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" @@ -25,35 +26,76 @@ namespace framework { namespace ir { namespace patterns { -struct EmbeddingEltwiseLayerNormPattern : public PatternBase { - EmbeddingEltwiseLayerNormPattern(PDPattern* pattern, - const std::string& name_scope) - : PatternBase(pattern, name_scope, "embedding_eltwise_layernorm") {} +// detect start pattern. 
+// +// in_var emb in_var emb +// | | | | +// lookup_table lookup_table +// | | +// lkt_var lkt_var +// \ / +// elementwise_add +// | +// elt_out_var +// +struct Embedding2Eltwise1Pattern : public PatternBase { + Embedding2Eltwise1Pattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "embedding2_eltwise1") {} - PDNode* operator()(); + void operator()(); PATTERN_DECL_NODE(lookup_table1_x); PATTERN_DECL_NODE(lookup_table2_x); - PATTERN_DECL_NODE(lookup_table3_x); - PATTERN_DECL_NODE(lookup_table1_w); PATTERN_DECL_NODE(lookup_table2_w); - PATTERN_DECL_NODE(lookup_table3_w); - PATTERN_DECL_NODE(lookup_table1); PATTERN_DECL_NODE(lookup_table2); - PATTERN_DECL_NODE(lookup_table3); - PATTERN_DECL_NODE(lookup_table1_out); PATTERN_DECL_NODE(lookup_table2_out); - PATTERN_DECL_NODE(lookup_table3_out); - - PATTERN_DECL_NODE(eltwise_add_12); - PATTERN_DECL_NODE(eltwise_add_12_out); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); +}; +// detect repeats inner pattern +// +// elt_out_var in_var emb +// \ | | +// \ lookup_table +// \ | +// \ lkt_var +// \ / +// elementwise_add +// | +// elt_out_var +// +struct Embedding1Eltwise1Pattern : public PatternBase { + Embedding1Eltwise1Pattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "embedding1_eltwise1") {} + void operator()(); + PATTERN_DECL_NODE(lookup_table1_x); + PATTERN_DECL_NODE(lookup_table1_w); + PATTERN_DECL_NODE(lookup_table1); + PATTERN_DECL_NODE(lookup_table1_out); + PATTERN_DECL_NODE(eltwise_add_in); PATTERN_DECL_NODE(eltwise_add); PATTERN_DECL_NODE(eltwise_add_out); +}; +// detect end pattern +// +// elementwise_add +// | +// elt_out_var +// scale | bias +// \ | / +// layer_norm +// +struct SkipLayerNorm : public PatternBase { + SkipLayerNorm(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "skip_layernorm") {} + void operator()(); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); PATTERN_DECL_NODE(layer_norm); PATTERN_DECL_NODE(layer_norm_bias); PATTERN_DECL_NODE(layer_norm_scale); @@ -79,6 +121,23 @@ struct EmbeddingEltwiseLayerNormPattern : public PatternBase { // // (word, pos, sent, weights_0, weights_1, weights_2, // scale, baias) embedding_eltwise_layernorm -> layer_norm_out +// +// +// in_var emb_var in_var emb_var in_var emb_var in_var emb_var +// | | | | | | | | +// lookup_table lookup_table lookup_table ... lookup_table +// | | | | +// lkt_var lkt_var lkt_var lkt_var +// \ / | ... | +// elementwise_add | | +// \ / | +// elementwise_add | +// | | +// elt_var / +// \ / +// elementwise_add +// | +// layer_norm class EmbeddingEltwiseLayerNormFusePass : public FusePassBase { public: diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..71c9dbae1a46af1ecae0aaff3fde52de8142d4bb --- /dev/null +++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.h" + +#include +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(SkipLayerNormFusePass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x, y) elementwise_add -> elementwise_out + // (elementwise_out, scale, bias) layer_norm -> layer_norm_out... + Layers layers; + auto* x0 = layers.data("x0", {1, 256, 1}); + auto* x1 = layers.data("x1", {1, 256, 1}); + auto* x2 = layers.data("x2", {1, 256, 1}); + auto* x3 = layers.data("x3", {1, 256, 1}); + + auto* emb0 = layers.data("emb0", {18000, 768}, true); + auto* emb1 = layers.data("emb1", {4, 768}, true); + auto* emb2 = layers.data("emb2", {513, 768}, true); + auto* emb3 = layers.data("emb3", {3, 768}, true); + + auto* lkt0 = layers.embedding(x0, emb0); + auto* lkt1 = layers.embedding(x1, emb1); + auto* lkt2 = layers.embedding(x2, emb2); + auto* lkt3 = layers.embedding(x3, emb3); + + auto* elementwise_out1 = layers.elementwise_add(lkt0, lkt2); + auto* elementwise_out2 = layers.elementwise_add(elementwise_out1, lkt1); + auto* elementwise_out3 = layers.elementwise_add(elementwise_out2, lkt3); + + auto* scale = layers.data("scale", {768}, true); + auto* bias = layers.data("bias", {768}, true); + layers.layer_norm(elementwise_out3, scale, bias); + + auto* y0 = layers.data("y0", {1, 256, 1}); + auto* y1 = layers.data("y1", {1, 256, 1}); + auto* y2 = layers.data("y2", {1, 256, 1}); + + auto* emb0y = layers.data("emb0y", {18000, 768}, true); + auto* emb1y = layers.data("emb1y", {4, 768}, true); + auto* emb2y = layers.data("emb2y", {513, 768}, true); + + auto* lkt0y = layers.embedding(y0, emb0y); + auto* lkt1y = layers.embedding(y1, emb1y); + auto* lkt2y = layers.embedding(y2, emb2y); + + auto* elementwise_out1y = layers.elementwise_add(lkt0y, lkt2y); + auto* elementwise_out2y = layers.elementwise_add(elementwise_out1y, lkt1y); + + auto* scaley = layers.data("scaley", {768}, true); + auto* biasy = layers.data("biasy", {768}, true); + layers.layer_norm(elementwise_out2y, scaley, biasy); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = + PassRegistry::Instance().Get("embedding_eltwise_layernorm_fuse_pass"); + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_fused_nodes_after = + GetNumOpNodes(graph, "fused_embedding_eltwise_layernorm"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 28, + platform::errors::PreconditionNotMet( + "The number of nodes before and after the fuse does " + "not meet expectations")); + PADDLE_ENFORCE_EQ( + num_fused_nodes_after, 2, + platform::errors::PreconditionNotMet( + "The number of fusion nodes does not meet expectations after fuse")); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(embedding_eltwise_layernorm_fuse_pass); diff --git 
a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc index 7672d48c5b5db96465489414fd6585b1591d417b..81037cb3149fd334e6b681e7fac76e9571582a74 100644 --- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc @@ -26,33 +26,19 @@ class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* context) const override { - PADDLE_ENFORCE_EQ(context->HasInput("WordId"), true, + PADDLE_ENFORCE_EQ(context->Inputs("Ids").size(), + context->Inputs("Embs").size(), platform::errors::InvalidArgument( - "Input(WordId) of EmbeddingEltWiseLayerNormOp should " - "not be null.")); - - PADDLE_ENFORCE_EQ( - context->HasInput("PosId"), true, - platform::errors::InvalidArgument( - "Input(PosId) of EmbeddingEltWiseLayerNormOp should not be null.")); - - PADDLE_ENFORCE_EQ(context->HasInput("SentId"), true, + "Two inputs of EmbeddingEltWiseLayerNormOp shoube be " + "the same size")); + PADDLE_ENFORCE_GE(context->Inputs("Embs").size(), 2UL, platform::errors::InvalidArgument( - "Input(SentId) of EmbeddingEltWiseLayerNormOp should " - "not be null.")); - - PADDLE_ENFORCE_EQ(context->HasInput("WordEmb"), true, - platform::errors::InvalidArgument( - "Input(WordEmb) of EmbeddingEltWiseLayerNormOp " - "should not be null.")); - PADDLE_ENFORCE_EQ(context->HasInput("PosEmb"), true, + "Input Embs of EmbeddingEltWiseLayerNormOp should " + "have at least 2 tensors")); + PADDLE_ENFORCE_GE(context->Inputs("Ids").size(), 2UL, platform::errors::InvalidArgument( - "Input(PosEmb) of EmbeddingEltWiseLayerNormOp should " - "not be null.")); - PADDLE_ENFORCE_EQ(context->HasInput("SentEmb"), true, - platform::errors::InvalidArgument( - "Input(SentEmb) of EmbeddingEltWiseLayerNormOp " - "should not be null.")); + "Input Ids of EmbeddingEltWiseLayerNormOp should " + "have at least 2 tensors")); PADDLE_ENFORCE_EQ( context->HasInput("Bias"), true, @@ -70,55 +56,55 @@ class EmbeddingEltWiseLayerNormOp : public framework::OperatorWithKernel { "Output(Out) of EmbeddingEltWiseLayerNormOp should not be null.")); // batch * seq_len * 1 - auto dims_word_id = context->GetInputDim("WordId"); + auto ids_dims = context->GetInputsDim("Ids"); // word_num * hidden - auto dims_word_emb = context->GetInputDim("WordEmb"); - auto dims_pos_emb = context->GetInputDim("PosEmb"); - auto dims_sent_emb = context->GetInputDim("SentEmb"); + auto embs_dims = context->GetInputsDim("Embs"); // hidden auto dims_bias = context->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ( - dims_word_emb[1], dims_bias[0], - platform::errors::InvalidArgument( - "The second dims (%d) of the Word Embedding should be equal " - "to the Bias's size(%d).", - dims_word_emb[1], dims_bias[0])); - PADDLE_ENFORCE_EQ(dims_word_emb.size(), 2, - platform::errors::InvalidArgument( - "The WordEmb dim's size shoule be 2, but found %d.", - dims_word_emb.size())); - PADDLE_ENFORCE_EQ(dims_pos_emb.size(), 2, - platform::errors::InvalidArgument( - "The PosEmb dim's size shoule be 2, but found %d.", - dims_pos_emb.size())); - PADDLE_ENFORCE_EQ(dims_sent_emb.size(), 2, - platform::errors::InvalidArgument( - "The SentEmb dim's size shoule be 2, but found %d.", - dims_sent_emb.size())); - PADDLE_ENFORCE_EQ( - dims_word_emb[1], dims_pos_emb[1], - platform::errors::InvalidArgument( - "The WordEmb first dim size(%d) shoule equal to PosEmb ones(%d).", - dims_word_emb[1], 
dims_pos_emb[1]));
-    PADDLE_ENFORCE_EQ(
-        dims_word_emb[1], dims_sent_emb[1],
-        platform::errors::InvalidArgument(
-            "The WordEmb first dim size(%d) shoule equal to SentEmb ones(%d).",
-            dims_word_emb[1], dims_sent_emb[1]));
+    int batch = ids_dims[0][0];
+    int seq_len = ids_dims[0][1];
+    int hidden = embs_dims[0][1];
+    for (size_t i = 0; i < embs_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(embs_dims[i].size(), 2,
+                        platform::errors::InvalidArgument(
+                            "The Emb dim's size should be 2, but found %d.",
+                            embs_dims[i].size()));
+      PADDLE_ENFORCE_EQ(
+          embs_dims[i][1], dims_bias[0],
+          platform::errors::InvalidArgument(
+              "The second dim (%d) of the Embedding should be equal "
+              "to the Bias's size (%d).",
+              embs_dims[i][1], dims_bias[0]));
+      PADDLE_ENFORCE_EQ(
+          embs_dims[i][1], hidden,
+          platform::errors::InvalidArgument(
+              "The Emb's second dim size (%d) should be equal to hidden (%d).",
+              embs_dims[i][1], hidden));
+    }
 
-    int batch = dims_word_id[0];
-    int seq_len = dims_word_id[1];
-    int hidden = dims_word_emb[1];
     auto dim_output = framework::make_ddim({batch, seq_len, hidden});
     context->SetOutputDim("Out", dim_output);
-    context->ShareLoD("WordId", /*->*/ "Out");
+    context->ShareLoD("Ids", /*->*/ "Out");
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "WordEmb");
-    return framework::OpKernelType(data_type, ctx.device_context());
+    auto inputs = ctx.MultiInput<framework::Tensor>("Embs");
+    auto input_data_type = framework::proto::VarType::Type(0);
+    bool flag = 0;
+    for (auto* input : inputs) {
+      if (input->IsInitialized() && input->numel() > 0) {
+        input_data_type = input->type();
+        flag = 1;
+        break;
+      }
+    }
+    if (flag == 0) {
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "All Inputs of fused_embedding_eltwise_layernorm OP are Empty!"));
+    }
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -126,15 +112,10 @@ class EmbeddingEltWiseLayerNormOpMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput("WordId", "The word id input of EmbeddingEltWiseLayerNorm op");
-    AddInput("PosId", "The position id input of EmbeddingEltWiseLayerNorm op");
-    AddInput("SentId", "The sentence id input of EmbeddingEltWiseLayerNorm op");
-    AddInput("WordEmb",
-             "The Word embedding input of EmbeddingEltWiseLayerNorm op");
-    AddInput("PosEmb",
-             "The Position embedding input of EmbeddingEltWiseLayerNorm op");
-    AddInput("SentEmb",
-             "The Sent embedding input of EmbeddingEltWiseLayerNorm op");
+    AddInput("Ids", "Input id tensors of EmbeddingEltWiseLayerNorm op")
+        .AsDuplicable();
+    AddInput("Embs", "Input emb tensors of EmbeddingEltWiseLayerNorm op")
+        .AsDuplicable();
     AddInput("Bias", "The LayerNorm Bias of EmbeddingEltWiseLayerNorm op");
     AddInput("Scale", "The LayerNorm Scale of EmbeddingEltWiseLayerNorm op");
     AddOutput("Out", "The output of EmbeddingEltWiseLayerNorm op");
@@ -157,10 +138,11 @@ class EmbeddingEltWiseLayerNormOpMaker
 EmbeddingEltWiseLayerNorm Operator.
 
 This op is used for optimize the following structure in ernie model.
-wordid -> lookup_table_op -> word
-posid -> lookup_table_op -> pos
-sentdid -> lookup_table_op -> sent
-word + pos + sent -> Y
+id1 -> lookup_table_op -> data1
+id2 -> lookup_table_op -> data2
+   ...
+idn -> lookup_table_op -> data_n
+data1 + data2 + ... + data_n -> Y
 Y -> layer_norm -> Out
 
 Not suggest to use in other case except has same structure as ernie.
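
[Reviewer note, not part of the patch] The variadic semantics of fused_embedding_eltwise_layernorm can be summarized with a small NumPy reference. The function and argument names (and the epsilon default) below are illustrative only; the sketch mirrors the CUDA kernel in the next file: gather one row per id from each embedding table, sum the rows, then apply layer norm with Scale/Bias.

    import numpy as np

    def fused_emb_eltwise_layernorm_ref(ids_list, embs_list, scale, bias, eps=1e-5):
        # ids_list:  int64 arrays of shape (batch, seq_len, 1), one per input
        # embs_list: float arrays of shape (vocab_i, hidden), one per input
        acc = None
        for ids, emb in zip(ids_list, embs_list):
            rows = emb[ids.squeeze(-1)]            # (batch, seq_len, hidden)
            acc = rows if acc is None else acc + rows
        mean = acc.mean(axis=-1, keepdims=True)    # layer norm over the hidden dim
        var = acc.var(axis=-1, keepdims=True)
        return (acc - mean) / np.sqrt(var + eps) * scale + bias
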
diff --git a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu index 2c0fe4963536c33c6d6e789c19dd5e5d101b3347..a6a63c5c780814c707ea9d94c777f887e2881f91 100644 --- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu +++ b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu @@ -16,6 +16,7 @@ #include #include #include // NOLINT +#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/detail/safe_ref.h" @@ -57,32 +58,28 @@ __device__ inline void LayerNorm(const cv2 &thread_data, const int ld, } template -__global__ void EmbEltwiseLayernormKernel( - int hidden, const int64_t *word_id_d, const int64_t *pos_id_d, - const int64_t *sent_id_d, const T *scale, const T *bias, const T *word_emb, - const T *pos_emb, const T *sent_emb, T *output, float eps) { +__global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids, + const T *scale, const T *bias, + const int64_t *embs, T *output, + float eps, int input_num) { cub::Sum pair_sum; // blockIdx.x: position in the sequence // blockIdx.y: batch // gridDim.x: Seq // gridDim.y: Batch - __shared__ int64_t word_id; - __shared__ int64_t pos_id; - __shared__ int64_t sent_id; + + extern __shared__ int64_t array_id[]; const T rhidden = T(1.f) / T(hidden); const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y; if (threadIdx.x == 0) { - word_id = word_id_d[seq_pos]; - pos_id = pos_id_d[seq_pos]; - sent_id = sent_id_d[seq_pos]; + for (int i = 0; i < input_num; ++i) { + const int64_t *ids_p = reinterpret_cast(ids[i]); + array_id[i] = ids_p[seq_pos]; + } } __syncthreads(); - // load word, pos, sentence embeddings and add them toghether - const int64_t woffset = word_id * hidden; - const int64_t poffset = pos_id * hidden; - const int64_t soffset = sent_id * hidden; const int64_t out_offset = seq_pos * hidden; cv2 thread_data; @@ -91,10 +88,10 @@ __global__ void EmbEltwiseLayernormKernel( #pragma unroll for (int it = threadIdx.x; it < hidden; it += TPB) { - const T w(word_emb[woffset + it]); - const T p(pos_emb[poffset + it]); - const T s(sent_emb[soffset + it]); - const T val = w + s + p; + T val = 0; + for (int i = 0; i < input_num; ++i) { + val += reinterpret_cast(embs[i])[array_id[i] * hidden + it]; + } output[out_offset + it] = val; const T rhiddenval = rhidden * val; @@ -112,47 +109,58 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { using Tensor = framework::Tensor; - auto *word_id = context.Input("WordId"); - auto *pos_id = context.Input("PosId"); - auto *sent_id = context.Input("SentId"); - - auto *word_emb = context.Input("WordEmb"); - auto *pos_emb = context.Input("PosEmb"); - auto *sent_emb = context.Input("SentEmb"); + auto &device_ctx = context.template device_context(); + auto ids = context.MultiInput("Ids"); + auto embs = context.MultiInput("Embs"); + int input_num = static_cast(ids.size()); + + framework::Tensor in_ids_(framework::proto::VarType::INT64), + in_embs_(framework::proto::VarType::INT64); + framework::DDim in_dim{input_num}; + int device_id; + cudaGetDevice(&device_id); + in_ids_.Resize(in_dim); + in_embs_.Resize(in_dim); + int64_t *in_ids_d = + in_ids_.mutable_data(platform::CUDAPlace(device_id)); + int64_t *in_embs_d = + in_embs_.mutable_data(platform::CUDAPlace(device_id)); + + 
    std::vector<int64_t> in1s, in2s;
+    for (int i = 0; i < input_num; ++i) {
+      in1s.push_back(reinterpret_cast<int64_t>(ids[i]->data<int64_t>()));
+      in2s.push_back(reinterpret_cast<int64_t>(embs[i]->data<T>()));
+    }
+
+    cudaMemcpyAsync(in_ids_d, in1s.data(), sizeof(int64_t) * input_num,
+                    cudaMemcpyHostToDevice, device_ctx.stream());
+    cudaMemcpyAsync(in_embs_d, in2s.data(), sizeof(int64_t) * input_num,
+                    cudaMemcpyHostToDevice, device_ctx.stream());
 
     auto *bias = context.Input<Tensor>("Bias");
     auto *scale = context.Input<Tensor>("Scale");
     auto *out = context.Output<Tensor>("Out");
 
-    auto *word_id_d = word_id->data<int64_t>();
-    auto *pos_id_d = pos_id->data<int64_t>();
-    auto *sent_id_d = sent_id->data<int64_t>();
+    // should be (B * S * hidden)
+    auto id0_dims = ids[0]->dims();
+    auto emb0_dims = embs[0]->dims();
 
-    auto *word_emb_d = word_emb->data<T>();
-    auto *pos_emb_d = pos_emb->data<T>();
-    auto *sent_emb_d = sent_emb->data<T>();
+    int batch = id0_dims[0];
+    int seq_len = id0_dims[1];
+    int hidden = emb0_dims[1];
 
     auto *bias_d = bias->data<T>();
     auto *scale_d = scale->data<T>();
     auto *output_d = out->mutable_data<T>(context.GetPlace());
 
-    // compute q*k with eltadd
-    auto &device_ctx = context.template device_context<DeviceContext>();
     float eps = context.Attr<float>("epsilon");
 
-    // should be (B * S * hidden)
-    auto word_id_dims = word_id->dims();
-    auto word_emb_dims = word_emb->dims();
-
-    int batch = word_id_dims[0];
-    int seq_len = word_id_dims[1];
-    int hidden = word_emb_dims[1];
-
     const unsigned tpb = 256;
     const dim3 grid(seq_len, batch, 1);
     const dim3 block(tpb, 1, 1);
-    EmbEltwiseLayernormKernel<T, tpb><<<grid, block, 0, device_ctx.stream()>>>(
-        hidden, word_id_d, pos_id_d, sent_id_d, scale_d, bias_d, word_emb_d,
-        pos_emb_d, sent_emb_d, output_d, eps);
+    int shared_bytes = input_num * sizeof(int64_t);
+    EmbEltwiseLayernormKernel<
+        T, tpb><<<grid, block, shared_bytes, device_ctx.stream()>>>(
+        hidden, in_ids_d, scale_d, bias_d, in_embs_d, output_d, eps, input_num);
   }
 };
 
diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_embedding_eltwise_layernorm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_embedding_eltwise_layernorm_fuse_pass.py
index 7680b4c7a0b6365d1260ec0c9b6fa9d7ed21007d..aa31bc2a35d5592809e832115aaa907072fcc87b 100644
--- a/python/paddle/fluid/tests/unittests/ir/test_ir_embedding_eltwise_layernorm_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/test_ir_embedding_eltwise_layernorm_fuse_pass.py
@@ -48,6 +48,39 @@ class EmbEltwiseLayerNormFusePassTest(PassTest):
         add2 = fluid.layers.elementwise_add(add1, sent_emb)
         hidden1 = fluid.layers.layer_norm(input=add2, begin_norm_axis=2)
 
+            id1 = fluid.layers.data(
+                name="id1",
+                shape=[1, 128, 1],
+                dtype="int64",
+                append_batch_size=False)
+            id2 = fluid.layers.data(
+                name="id2",
+                shape=[1, 128, 1],
+                dtype="int64",
+                append_batch_size=False)
+            id3 = fluid.layers.data(
+                name="id3",
+                shape=[1, 128, 1],
+                dtype="int64",
+                append_batch_size=False)
+            id4 = fluid.layers.data(
+                name="id4",
+                shape=[1, 128, 1],
+                dtype="int64",
+                append_batch_size=False)
+            emb1 = fluid.layers.embedding(
+                input=id1, size=(128, 768), dtype='float32')
+            emb2 = fluid.layers.embedding(
+                input=id2, size=(128, 768), dtype='float32')
+            emb3 = fluid.layers.embedding(
+                input=id3, size=(128, 768), dtype='float32')
+            emb4 = fluid.layers.embedding(
+                input=id4, size=(128, 768), dtype='float32')
+            add_1 = fluid.layers.elementwise_add(emb1, emb2)
+            add_2 = fluid.layers.elementwise_add(add_1, emb3)
+            add_3 = fluid.layers.elementwise_add(add_2, emb4)
+            hidden_1 = fluid.layers.layer_norm(input=add_3, begin_norm_axis=2)
+
         self.feeds = {
             "word_id": np.random.randint(
                 low=0, high=128, size=(1, 128, 1)).astype("int64"),
@@ -55,11 +88,19 @@ class
EmbEltwiseLayerNormFusePassTest(PassTest): low=0, high=128, size=(1, 128, 1)).astype("int64"), "sent_id": np.random.randint( low=0, high=128, size=(1, 128, 1)).astype("int64"), + "id1": np.random.randint( + low=0, high=128, size=(1, 128, 1)).astype("int64"), + "id2": np.random.randint( + low=0, high=128, size=(1, 128, 1)).astype("int64"), + "id3": np.random.randint( + low=0, high=128, size=(1, 128, 1)).astype("int64"), + "id4": np.random.randint( + low=0, high=128, size=(1, 128, 1)).astype("int64"), } - self.fetch_list = [hidden1] + self.fetch_list = [hidden1, hidden_1] self.pass_names = "embedding_eltwise_layernorm_fuse_pass" self.fused_op_type = "fused_embedding_eltwise_layernorm" - self.num_fused_ops = 1 + self.num_fused_ops = 2 def test_check_output(self): use_gpu_set = [True]
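
[Reviewer note, not part of the patch] For readers of BuildFusion above: the three pattern detectors only collect candidate sub-patterns, and a fused op is emitted once per connected chain found by the stitching loop over their results. The Python sketch below restates that loop with hypothetical names; start_outs, inner_ins, inner_outs and end_elt_outs stand for the vectors filled by the three handlers.

    def stitch_matches(start_outs, inner_ins, inner_outs, end_elt_outs):
        fusion_ids = []
        for i, cur in enumerate(start_outs):
            js, moved = [], True
            while moved:                     # follow the chain of inner adds
                moved = False
                for j, tmp_in in enumerate(inner_ins):
                    if tmp_in is cur:        # this inner match extends the chain
                        cur, moved = inner_outs[j], True
                        js.append(j)
                        break
            for k, elt_out in enumerate(end_elt_outs):
                if cur is elt_out:           # chain ends at the add feeding layer_norm
                    fusion_ids.append((i, k, js))
                    break
        return fusion_ids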