From b8333edef6e1e7eb1a0c22121d375c9660d91e61 Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Sun, 13 Oct 2019 00:40:16 -0500 Subject: [PATCH] Add Multihead matmul fuse pass (#20167) * Add multihead fuse pass for ernie opt * Refine softmax test=develop * Refine cuda kernel * Refine cuda version * Refine cmake test=develop * refine header file * refine test case and pass * refine comments --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../ir/multihead_matmul_fuse_pass.cc | 443 ++++++++++++++++++ .../framework/ir/multihead_matmul_fuse_pass.h | 103 ++++ .../ir/multihead_matmul_fuse_pass_tester.cc | 107 +++++ .../fluid/framework/ir/pass_tester_helper.h | 52 ++ .../inference/api/paddle_pass_builder.cc | 5 +- 6 files changed, 710 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h create mode 100755 paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 6db8487d67e..9232c99938d 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -74,6 +74,7 @@ pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) +pass_library(multihead_matmul_fuse_pass inference) if(WITH_GPU) pass_library(cudnn_placement_pass base DEPS placement_pass_base) endif() @@ -126,6 +127,7 @@ cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.c cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass) cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass) +cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass) cc_test(test_conv_bn_fuse_pass SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass) if(WITH_GPU) cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass) diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc new file mode 100644 index 00000000000..6236c16d785 --- /dev/null +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -0,0 +1,443 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/lod_tensor.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +static void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) { + if (op->IsOp() && op->Op()) { + new_var->inputs.push_back(op); + for (size_t i = 0; i < op->outputs.size(); ++i) { + if (op->outputs[i] == old_var) { + op->outputs[i] = new_var; + op->Op()->RenameOutput(old_var->Name(), new_var->Name()); + } + } + } +} + +static int BuildFusion(Graph* graph, const std::string& name_scope) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + // Create pattern. + MultiHeadMatmulPattern multihead_pattern(pattern, name_scope); + + PDNode* x = + pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable(); + + multihead_pattern(x); + // Create New OpDesc + auto fuse_creater = [&]( + Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, + Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b, + Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2, + Node* reshape2_qkv_out, Node* scale, Node* scale_out) { + auto scale_attr = boost::get(scale->Op()->GetAttr("scale")); + // auto scale_bias = boost::get(scale->Op()->GetAttr("bias")); + // bool after_scale = + // boost::get(scale->Op()->GetAttr("bias_after_scale")); + + // create multihead + OpDesc multihead_op_desc; + + // create tmp tensor + VarDesc k_var_desc(*mul1_out->Var()); + k_var_desc.SetName("K" + mul1_out->Name()); + auto* k_var_node = graph->CreateVarNode(&k_var_desc); + + VarDesc q_var_desc(*mul0_out->Var()); + q_var_desc.SetName("Q" + mul0_out->Name()); + auto* q_var_node = graph->CreateVarNode(&q_var_desc); + + VarDesc v_var_desc(*mul2_out->Var()); + v_var_desc.SetName("V" + mul2_out->Name()); + auto* v_var_node = graph->CreateVarNode(&v_var_desc); + + auto reshape_desc = reshape2->Op(); + int head_number = + boost::get>(reshape_desc->GetAttr("shape")).at(2); + + ReplaceOutputVar(mul0, mul0_out, q_var_node); + ReplaceOutputVar(mul1, mul1_out, k_var_node); + ReplaceOutputVar(mul2, mul2_out, v_var_node); + + multihead_op_desc.SetType("multihead_matmul"); + multihead_op_desc.SetInput("Q", {q_var_node->Name()}); + multihead_op_desc.SetInput("K", {k_var_node->Name()}); + multihead_op_desc.SetInput("V", {v_var_node->Name()}); + + multihead_op_desc.SetInput("BiasQ", {eltadd0_b->Name()}); + multihead_op_desc.SetInput("BiasK", {eltadd1_b->Name()}); + multihead_op_desc.SetInput("BiasV", {eltadd2_b->Name()}); + multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()}); + + multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); + multihead_op_desc.SetAttr("alpha", scale_attr); + multihead_op_desc.SetAttr("head_number", head_number); + + auto* multihead = graph->CreateOpNode(&multihead_op_desc); + IR_NODE_LINK_TO(q_var_node, multihead); + IR_NODE_LINK_TO(k_var_node, multihead); + IR_NODE_LINK_TO(v_var_node, multihead); + + IR_NODE_LINK_TO(eltadd0_b, multihead); + IR_NODE_LINK_TO(eltadd1_b, multihead); + IR_NODE_LINK_TO(eltadd2_b, multihead); + IR_NODE_LINK_TO(eltadd_qk_b, multihead); + + IR_NODE_LINK_TO(multihead, reshape2_qkv_out); + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, + multihead_pattern); + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, + multihead_pattern); + + fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, + eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0, + reshape2_qkv_out, scale, scale_out); + + std::unordered_set marked_nodes( + {eltadd0, + eltadd1, + eltadd2, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, // dropout_qk, dropout_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0_out, + mul1_out, + mul2_out, + reshape2_qkv, + scale}); + // Remove unneeded nodes. + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + +PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { + // Create shared nodes. + auto* layer_norm = pattern->NewNode(layer_norm_repr()); + + auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()); + layer_norm_out_var->assert_is_op_input("mul"); + + // First path with scale + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul"); + auto* mul0_w_var = pattern->NewNode(mul0_w_repr()) + ->AsInput() + ->assert_is_op_input("mul", "Y"); + auto* mul0_out_var = + pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul"); + + decltype(mul0) eltadd0; + decltype(mul0) eltadd0_b_var; + decltype(mul0) eltadd0_out_var; + + mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + eltadd0_b_var = pattern->NewNode(eltadd0_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd0_out_var = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_0 = + pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2"); + + auto* reshape2_0_out_var = + pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2"); + reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_0 = + pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2"); + auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_0_out_var->AsIntermediate()->assert_is_op_input("scale"); + + auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + auto* scale_out_var = + pattern->NewNode(scale_out_repr())->assert_is_op_output("scale"); + scale_out_var->AsIntermediate()->assert_is_op_input("matmul"); + + auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul"); + auto* matmul_qk_out_var = + pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul"); + matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto* eltadd_qk = + pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add"); + auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax"); + + auto* softmax_qk = + pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax"); + auto* softmax_qk_out_var = + pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax"); + softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul"); + + auto* matmul_qkv = + pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul"); + auto* matmul_qkv_out_var = + pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul"); + matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_qkv = + pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2"); + auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_qkv = + pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2"); + auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr()) + ->assert_is_op_output("reshape2"); + reshape2_qkv_out_var->assert_is_op_input("mul"); + + // Second path to matmul + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul"); + auto* mul1_w_var = pattern->NewNode(mul1_w_repr()) + ->AsInput() + ->assert_is_op_input("mul", "Y"); + auto* mul1_out_var = + pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul"); + + decltype(mul1) eltadd1; + decltype(mul1) eltadd1_b_var; + decltype(mul1) eltadd1_out_var; + + mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + eltadd1_b_var = pattern->NewNode(eltadd1_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd1_out_var = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_1 = + pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2"); + + auto* reshape2_1_out_var = + pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2"); + reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_1 = + pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2"); + auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_1_out_var->AsIntermediate()->assert_is_op_input( + "matmul"); // link to matmul qk + + // Third path to matmul + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul"); + auto* mul2_w_var = pattern->NewNode(mul2_w_repr()) + ->AsInput() + ->assert_is_op_input("mul", "Y"); + auto* mul2_out_var = + pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul"); + + decltype(mul2) eltadd2; + decltype(mul2) eltadd2_b_var; + decltype(mul2) eltadd2_out_var; + + mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); + eltadd2_b_var = pattern->NewNode(eltadd2_b_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + + eltadd2_out_var = pattern->NewNode(eltadd2_out_repr()) + ->assert_is_op_output("elementwise_add"); + eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2"); + + auto* reshape2_2 = + pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2"); + + auto* reshape2_2_out_var = + pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2"); + reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2"); + + auto* transpose2_2 = + pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2"); + auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr()) + ->assert_is_op_output("transpose2"); + transpose2_2_out_var->AsIntermediate()->assert_is_op_input( + "matmul"); // link to matmul qkv + + // Link all nodes together + layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var}); + // Q path + mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var}); + eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var}); + + reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); + transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); + scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var}); + // K path + mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var}); + eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var}); + reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); + transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); + // compute q*k + matmul_qk->LinksFrom({scale_out_var, transpose2_1_out_var}) + .LinksTo({matmul_qk_out_var}); + eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) + .LinksTo({eltadd_qk_out_var}); + softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); + // V path + mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var}); + eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var}); + reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); + transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); + // compute q*k*v + matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}) + .LinksTo({matmul_qkv_out_var}); + transpose2_qkv->LinksFrom({matmul_qkv_out_var}) + .LinksTo({transpose2_qkv_out_var}); + reshape2_qkv->LinksFrom({transpose2_qkv_out_var}) + .LinksTo({reshape2_qkv_out_var}); + + return transpose2_2_out_var; +} + +} // namespace patterns + +void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL(graph); + FusePassBase::Init(name_scope_, graph); + + int fusion_count = patterns::BuildFusion(graph, name_scope_); + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(multihead_matmul_fuse_pass, + paddle::framework::ir::MultiHeadMatmulFusePass); diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h new file mode 100644 index 00000000000..ab58d9468e5 --- /dev/null +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h @@ -0,0 +1,103 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct MultiHeadMatmulPattern : public PatternBase { + MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "multihead_matmul") {} + + PDNode* operator()(PDNode* x); + + // declare operator node's name + // PATTERN_DECL_NODE(dropout); + // PATTERN_DECL_NODE(dropout_out); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(mul0); + PATTERN_DECL_NODE(mul1); + PATTERN_DECL_NODE(mul2); + PATTERN_DECL_NODE(mul0_w); + PATTERN_DECL_NODE(mul1_w); + PATTERN_DECL_NODE(mul2_w); + PATTERN_DECL_NODE(mul0_out); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(mul2_out); + PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(eltadd1_out); + PATTERN_DECL_NODE(eltadd2_out); + PATTERN_DECL_NODE(reshape2_0); + PATTERN_DECL_NODE(reshape2_1); + PATTERN_DECL_NODE(reshape2_2); + PATTERN_DECL_NODE(reshape2_qkv); + PATTERN_DECL_NODE(reshape2_0_out); + PATTERN_DECL_NODE(reshape2_1_out); + PATTERN_DECL_NODE(reshape2_2_out); + PATTERN_DECL_NODE(reshape2_qkv_out); + PATTERN_DECL_NODE(transpose2_0); + PATTERN_DECL_NODE(transpose2_1); + PATTERN_DECL_NODE(transpose2_2); + PATTERN_DECL_NODE(transpose2_qkv); + PATTERN_DECL_NODE(transpose2_0_out); + PATTERN_DECL_NODE(transpose2_1_out); + PATTERN_DECL_NODE(transpose2_2_out); + PATTERN_DECL_NODE(transpose2_qkv_out); + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(matmul_qk); + PATTERN_DECL_NODE(matmul_qk_out); + PATTERN_DECL_NODE(eltadd_qk); + PATTERN_DECL_NODE(eltadd_qk_b); + PATTERN_DECL_NODE(eltadd_qk_out); + PATTERN_DECL_NODE(softmax_qk); + PATTERN_DECL_NODE(softmax_qk_out); + // PATTERN_DECL_NODE(dropout_qk); + // PATTERN_DECL_NODE(dropout_qk_out); + + PATTERN_DECL_NODE(matmul_qkv); + PATTERN_DECL_NODE(matmul_qkv_out); +}; +} // namespace patterns + +// The MulGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op. +class MultiHeadMatmulFusePass : public FusePassBase { + public: + virtual ~MultiHeadMatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"multihead_matmul_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc new file mode 100755 index 00000000000..d0a5c8c6fe8 --- /dev/null +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc @@ -0,0 +1,107 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT +#include +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(MultiHeadMatmulFusePass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x) layer_norm -> layer_norm_out + // (layer_norm_out, weights_0) mul -> mul_out0 + // (layer_norm_out, weights_1) mul -> mul_out1 + // (layer_norm_out, weights_2) mul -> mul_out2 + // (mul_out0, bias_0) elementweise_add -> eltadd_0 + // (mul_out1, bias_1) elementweise_add -> eltadd_1 + // (mul_out2, bias_2) elementweise_add -> eltadd_2 + // (eltadd_0) reshape2 -> reshape_0 + // (eltadd_1) reshape2 -> reshape_1 + // (eltadd_2) reshape2 -> reshape_2 + // (reshape_0) transpose2 -> transpose_0 + // (reshape_1) transpose2 -> transpose_1 + // (reshape_2) transpose2 -> transpose_2 + // (transpose_0) scale -> scale_0 + // (scale_0, transpose_1) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk, transpose_2) matmul -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // (reshape_qkv) mul -> mul_qkv + Layers layers; + auto* x = layers.data("x", {128, 768}); + auto out = layers.layer_norm(x); + auto* layer_out = out[0]; + + auto* weights_0 = layers.data("weights0", {768, 768}, true); + auto* weights_1 = layers.data("weights1", {768, 768}, true); + auto* weights_2 = layers.data("weights2", {768, 768}, true); + + auto* mul_out_0 = layers.mul(layer_out, weights_0); + auto* mul_out_1 = layers.mul(layer_out, weights_1); + auto* mul_out_2 = layers.mul(layer_out, weights_2); + + auto* b0 = layers.data("bias_0", {768}, true); + auto* b1 = layers.data("bias_1", {768}, true); + auto* b2 = layers.data("bias_2", {768}, true); + + auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0); + auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1); + auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2); + + std::vector shape = {128, 12, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape); + auto* reshape_1 = layers.reshape2(elementwise_out_1, shape); + auto* reshape_2 = layers.reshape2(elementwise_out_2, shape); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis); + auto* transpose_1 = layers.transpose2(reshape_1, axis); + auto* transpose_2 = layers.transpose2(reshape_2, axis); + + auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false); + auto* matmul_qk = layers.matmul(scale_0, transpose_1); + + auto* bqk = layers.data("biasqk", {768}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + + auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}); + auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {128, 768}); + auto* weights_l = layers.data("weightsl", {768, 768}, true); + layers.mul(reshape_qkv_out, weights_l); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass"); + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_fused_nodes_after = GetNumOpNodes(graph, "multihead_matmul"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 29); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(multihead_matmul_fuse_pass); diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 970bd2d58d5..0601b8801af 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -175,6 +175,58 @@ struct Layers { return outs; } + VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("matmul"); + op->SetInput("X", {x->Name()}); + op->SetInput("Y", {y->Name()}); + op->SetOutput("Out", {out->Name()}); + return out; + } + + VarDesc* transpose2(VarDesc* x, std::vector axis) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("transpose2"); + op->SetInput("X", {x->Name()}); + op->SetAttr("axis", axis); + op->SetOutput("Out", {out->Name()}); + return out; + } + + VarDesc* reshape2(VarDesc* x, std::vector shape) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("reshape2"); + op->SetInput("X", {x->Name()}); + op->SetAttr("shape", shape); + op->SetOutput("Out", {out->Name()}); + return out; + } + + VarDesc* softmax(VarDesc* x, int axis) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("softmax"); + op->SetInput("X", {x->Name()}); + op->SetAttr("axis", axis); + op->SetOutput("Out", {out->Name()}); + return out; + } + + VarDesc* scale(VarDesc* x, float scale, float bias, bool bias_after) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("scale"); + op->SetInput("X", {x->Name()}); + op->SetAttr("scale", scale); + op->SetAttr("bias", bias); + op->SetAttr("bias_after_scale", bias_after); + op->SetOutput("Out", {out->Name()}); + return out; + } + std::vector batch_norm(VarDesc* x, VarDesc* scale, VarDesc* bias, VarDesc* mean, VarDesc* variance) { VarDesc* y = lod_tensor(unique_name()); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index e436367872b..b8afbb4099d 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -106,12 +106,13 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { // "identity_scale_op_clean_pass", // "is_test_pass", // "simplify_with_basic_ops_pass", // - "fc_fuse_pass", // - "fc_elementwise_layernorm_fuse_pass", // "conv_affine_channel_fuse_pass", // "conv_eltwiseadd_affine_channel_fuse_pass", // "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", // + "multihead_matmul_fuse_pass", + "fc_fuse_pass", // + "fc_elementwise_layernorm_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 "conv_elementwise_add_act_fuse_pass", // -- GitLab