From 7b7e605189779e93173b650662cabb06c2d697d7 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Tue, 7 Jul 2020 13:44:10 +0800 Subject: [PATCH] [Fix BUGs]: fix multhead matmul pass's instable bug (#25123) * fix multhead matmul's instable test=develop * fix multihead matmul bug test=develop * fix converage problem test=develop --- .../framework/ir/graph_pattern_detector.cc | 10 +- .../framework/ir/graph_pattern_detector.h | 4 +- .../ir/multihead_matmul_fuse_pass.cc | 135 ++++++++---------- .../framework/ir/multihead_matmul_fuse_pass.h | 5 +- .../tensorrt/plugin/qkv_to_context_plugin.cu | 2 +- 5 files changed, 72 insertions(+), 84 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index f58e6c8bff9..ff6dffa704e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -140,7 +140,7 @@ void GraphPatternDetector::ValidateByNodeRole( subgraphs->begin(), subgraphs->end(), [](const GraphPatternDetector::subgraph_t &subgraph) -> bool { // Collect the inputs and outputs. - std::unordered_set ios; + std::set ios; for (auto &item : subgraph) { if (!item.first->IsIntermediate()) { ios.insert(item.second); @@ -166,7 +166,7 @@ void GraphPatternDetector::ValidateByNodeRole( } struct HitGroup { - std::unordered_map roles; + std::map roles; bool Match(Node *node, PDNode *pat) { if (nodes_.count(node)) { @@ -184,7 +184,7 @@ struct HitGroup { } private: - std::unordered_set nodes_; + std::set nodes_; }; // Tell whether Node a links to b. @@ -283,7 +283,7 @@ void GraphPatternDetector::UniquePatterns( if (subgraphs->empty()) return; std::vector result; - std::unordered_set set; + std::set set; std::hash hasher; for (auto &g : *subgraphs) { // Sort the items in the sub-graph, and transform to a string key. @@ -305,7 +305,7 @@ void GraphPatternDetector::UniquePatterns( void GraphPatternDetector::RemoveOverlappedMatch( std::vector *subgraphs) { std::vector result; - std::unordered_set node_set; + std::set node_set; for (const auto &subgraph : *subgraphs) { bool valid = true; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 422ad1ef47a..e1cce7848dd 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -231,7 +231,7 @@ class PDPattern { std::vector> nodes_; std::vector edges_; - std::unordered_map node_map_; + std::map node_map_; static size_t id_; }; @@ -263,7 +263,7 @@ class PDPattern { */ class GraphPatternDetector { public: - using subgraph_t = std::unordered_map; + using subgraph_t = std::map; // Operate on the detected pattern. using handle_t = diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 85d20a7b9a2..40e01c75bb9 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { // Create pattern. MultiHeadMatmulPattern multihead_pattern(pattern, name_scope); - PDNode* x = - pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable(); - - multihead_pattern(x); + multihead_pattern(); // Create New OpDesc auto fuse_creater = [&]( - Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, + Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) { @@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); @@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); - fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, + fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out); @@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { return fusion_count; } -PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { - // Create shared nodes. - auto* layer_norm = pattern->NewNode(layer_norm_repr()); - - auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()); - layer_norm_out_var->assert_is_op_input("mul"); +PDNode* MultiHeadMatmulPattern::operator()() { + auto* input0 = pattern->NewNode(input0_repr()); + input0->assert_is_op_input("mul"); // First path with scale auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul"); @@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { transpose2_2_out_var->AsIntermediate()->assert_is_op_input( "matmul"); // link to matmul qkv - // Link all nodes together - layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var}); // Q path - mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var}); + mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var}); eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var}); reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var}); // K path - mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var}); + mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var}); eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var}); reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); @@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { .LinksTo({eltadd_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); // V path - mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var}); + mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var}); eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var}); reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); @@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, // Create pattern. MultiHeadMatmulPattern multihead_pattern(pattern, name_scope); - PDNode* x = - pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable(); - - multihead_pattern(x); + multihead_pattern(); // Create New OpDesc auto fuse_creater = [&]( - Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, + Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w, Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) { @@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]}); - // create a new var in scope - VarDesc combined_w_desc( - patterns::PDNodeName(name_scope, "multi_head_combined_weight")); - combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); - combined_w_desc.SetDataType(wq_tensor->type()); - combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel()); - combined_w_desc.SetPersistable(true); - - // create a new var in scope - VarDesc combined_bias_desc( - patterns::PDNodeName(name_scope, "multi_head_combined_bias")); - combined_bias_desc.SetShape({3, bq_tensor->dims()[0]}); - combined_bias_desc.SetDataType(bq_tensor->type()); - combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel()); - combined_bias_desc.SetPersistable(true); - - auto* combined_w_node = graph->CreateVarNode(&combined_w_desc); - auto* combined_w_tensor = - scope->Var(combined_w_node->Name())->GetMutable(); - - combined_w_tensor->Resize(combined_w_dims); - auto* combined_w_data = - combined_w_tensor->mutable_data(platform::CPUPlace()); + // reuse the mul0_w and eltadd_0_b nodes for the combined nodes. + auto* combined_w_desc = mul0_w->Var(); + combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); + combined_w_desc->SetPersistable(true); + + auto* combined_bias_desc = eltadd0_b->Var(); + combined_bias_desc->SetShape({3, bq_tensor->dims()[0]}); + combined_bias_desc->SetPersistable(true); + + framework::LoDTensor tmp_combined_w_tensor; + tmp_combined_w_tensor.Resize(combined_w_dims); + auto* tmp_combined_w_data = + tmp_combined_w_tensor.mutable_data(platform::CPUPlace()); + std::vector w_vec = {wq_data, wk_data, wv_data}; int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2]; // Combine the three fc weights together. @@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, for (int k = 0; k < dims_w; k++) { int out_index = i * (3 * dims_w) + j * dims_w + k; int in_index = i * dims_w + k; - combined_w_data[out_index] = w_vec[j][in_index]; + tmp_combined_w_data[out_index] = w_vec[j][in_index]; } } } - scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()}); - auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc); - auto* combined_bias_tensor = - scope->Var(combined_bias_node->Name())->GetMutable(); - - combined_bias_tensor->Resize(combined_bias_dims); - auto* combined_bias_data = - combined_bias_tensor->mutable_data(platform::CPUPlace()); + + wq_tensor->Resize(combined_w_dims); + auto* new_combined_w_data = + wq_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_combined_w_data, tmp_combined_w_data, + sizeof(float) * wq_tensor->numel()); + + scope->EraseVars({mul1_w->Name(), mul2_w->Name()}); + + framework::LoDTensor tmp_combined_bias_tensor; + tmp_combined_bias_tensor.Resize(combined_bias_dims); + auto* tmp_combined_bias_data = + tmp_combined_bias_tensor.mutable_data(platform::CPUPlace()); + size_t bias_size = bq_tensor->numel(); - memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size); - memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size); - memcpy(combined_bias_data + 2 * bias_size, bv_data, + memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + bias_size, bk_data, + sizeof(float) * bias_size); + memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data, sizeof(float) * bias_size); - scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()}); + bq_tensor->Resize(combined_bias_dims); + auto* new_combined_bias_data = + bq_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_combined_bias_data, tmp_combined_bias_data, + sizeof(float) * bq_tensor->numel()); + + scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()}); auto reshape_desc = reshape2->Op(); int head_number = @@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, OpDesc multihead_op_desc; multihead_op_desc.SetType("multihead_matmul"); - multihead_op_desc.SetInput("Input", {layer_norm_out->Name()}); - multihead_op_desc.SetInput("W", {combined_w_node->Name()}); - multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()}); + multihead_op_desc.SetInput("Input", {input0->Name()}); + multihead_op_desc.SetInput("W", {mul0_w->Name()}); + multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()}); multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()}); multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); @@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, auto* multihead = graph->CreateOpNode(&multihead_op_desc); - IR_NODE_LINK_TO(layer_norm_out, multihead); - IR_NODE_LINK_TO(combined_w_node, multihead); - IR_NODE_LINK_TO(combined_bias_node, multihead); + IR_NODE_LINK_TO(input0, multihead); + IR_NODE_LINK_TO(mul0_w, multihead); + IR_NODE_LINK_TO(eltadd0_b, multihead); IR_NODE_LINK_TO(eltadd_qk_b, multihead); IR_NODE_LINK_TO(multihead, reshape2_qkv_out); @@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern); - GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, - multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); @@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, multihead_pattern); - fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, - mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, - eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out); + fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, + mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, + reshape2_0, reshape2_qkv_out, scale, scale_out); std::unordered_set marked_nodes({eltadd0, eltadd1, eltadd2, - eltadd0_b, eltadd1_b, eltadd2_b, eltadd0_out, @@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, mul0_out, mul1_out, mul2_out, - mul0_w, mul1_w, mul2_w, reshape2_qkv, diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h index d6299c39c73..0afa00fc62a 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h @@ -29,11 +29,10 @@ struct MultiHeadMatmulPattern : public PatternBase { MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "multihead_matmul") {} - PDNode* operator()(PDNode* x); + PDNode* operator()(); // declare operator node's name - PATTERN_DECL_NODE(layer_norm); - PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(input0); PATTERN_DECL_NODE(mul0); PATTERN_DECL_NODE(mul1); PATTERN_DECL_NODE(mul2); diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu index 0f9c94a0afb..fe3ea180593 100644 --- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu @@ -167,7 +167,7 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions( ret.nbDims = 5; ret.d[0] = inputs[0].d[0]; ret.d[1] = inputs[0].d[1]; - ret.d[2] = expr_builder.constant(hidden_); + ret.d[2] = expr_builder.constant(head_size_ * head_number_); ret.d[3] = expr_builder.constant(1); ret.d[4] = expr_builder.constant(1); return ret; -- GitLab