未验证 提交 b8333ede 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

Add Multihead matmul fuse pass (#20167)

* Add multihead fuse pass for ernie opt

* Refine softmax

test=develop

* Refine cuda kernel

* Refine cuda version

* Refine cmake

test=develop

* refine header file

* refine test case and pass
* refine comments
上级 fc6ec3b9
......@@ -74,6 +74,7 @@ pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference)
if(WITH_GPU)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
endif()
......@@ -126,6 +127,7 @@ cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.c
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass)
cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass)
cc_test(test_conv_bn_fuse_pass SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass)
if(WITH_GPU)
cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
static void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) {
if (op->IsOp() && op->Op()) {
new_var->inputs.push_back(op);
for (size_t i = 0; i < op->outputs.size(); ++i) {
if (op->outputs[i] == old_var) {
op->outputs[i] = new_var;
op->Op()->RenameOutput(old_var->Name(), new_var->Name());
}
}
}
}
static int BuildFusion(Graph* graph, const std::string& name_scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
PDNode* x =
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
multihead_pattern(x);
// Create New OpDesc
auto fuse_creater = [&](
Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b,
Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2,
Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
auto scale_attr = boost::get<float>(scale->Op()->GetAttr("scale"));
// auto scale_bias = boost::get<float>(scale->Op()->GetAttr("bias"));
// bool after_scale =
// boost::get<bool>(scale->Op()->GetAttr("bias_after_scale"));
// create multihead
OpDesc multihead_op_desc;
// create tmp tensor
VarDesc k_var_desc(*mul1_out->Var());
k_var_desc.SetName("K" + mul1_out->Name());
auto* k_var_node = graph->CreateVarNode(&k_var_desc);
VarDesc q_var_desc(*mul0_out->Var());
q_var_desc.SetName("Q" + mul0_out->Name());
auto* q_var_node = graph->CreateVarNode(&q_var_desc);
VarDesc v_var_desc(*mul2_out->Var());
v_var_desc.SetName("V" + mul2_out->Name());
auto* v_var_node = graph->CreateVarNode(&v_var_desc);
auto reshape_desc = reshape2->Op();
int head_number =
boost::get<std::vector<int>>(reshape_desc->GetAttr("shape")).at(2);
ReplaceOutputVar(mul0, mul0_out, q_var_node);
ReplaceOutputVar(mul1, mul1_out, k_var_node);
ReplaceOutputVar(mul2, mul2_out, v_var_node);
multihead_op_desc.SetType("multihead_matmul");
multihead_op_desc.SetInput("Q", {q_var_node->Name()});
multihead_op_desc.SetInput("K", {k_var_node->Name()});
multihead_op_desc.SetInput("V", {v_var_node->Name()});
multihead_op_desc.SetInput("BiasQ", {eltadd0_b->Name()});
multihead_op_desc.SetInput("BiasK", {eltadd1_b->Name()});
multihead_op_desc.SetInput("BiasV", {eltadd2_b->Name()});
multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
multihead_op_desc.SetAttr("alpha", scale_attr);
multihead_op_desc.SetAttr("head_number", head_number);
auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(q_var_node, multihead);
IR_NODE_LINK_TO(k_var_node, multihead);
IR_NODE_LINK_TO(v_var_node, multihead);
IR_NODE_LINK_TO(eltadd0_b, multihead);
IR_NODE_LINK_TO(eltadd1_b, multihead);
IR_NODE_LINK_TO(eltadd2_b, multihead);
IR_NODE_LINK_TO(eltadd_qk_b, multihead);
IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
reshape2_qkv_out, scale, scale_out);
std::unordered_set<const Node*> marked_nodes(
{eltadd0,
eltadd1,
eltadd2,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out, // dropout_qk, dropout_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
mul0_out,
mul1_out,
mul2_out,
reshape2_qkv,
scale});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
// Create shared nodes.
auto* layer_norm = pattern->NewNode(layer_norm_repr());
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr());
layer_norm_out_var->assert_is_op_input("mul");
// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("mul");
decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
decltype(mul0) eltadd0_out_var;
mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add");
eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var =
pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("scale");
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out_var =
pattern->NewNode(scale_out_repr())->assert_is_op_output("scale");
scale_out_var->AsIntermediate()->assert_is_op_input("matmul");
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add");
eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_op_input("matmul");
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_op("matmul");
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_op_output("matmul");
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("mul");
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("mul");
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("mul");
decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
decltype(mul1) eltadd1_out_var;
mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add");
eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var =
pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("mul");
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("mul", "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("mul");
decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
decltype(mul2) eltadd2_out_var;
mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add");
eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var =
pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qkv
// Link all nodes together
layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var});
// Q path
mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
// K path
mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// compute q*k
matmul_qk->LinksFrom({scale_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// compute q*k*v
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
return transpose2_2_out_var;
}
} // namespace patterns
void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
FusePassBase::Init(name_scope_, graph);
int fusion_count = patterns::BuildFusion(graph, name_scope_);
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multihead_matmul_fuse_pass,
paddle::framework::ir::MultiHeadMatmulFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct MultiHeadMatmulPattern : public PatternBase {
MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul") {}
PDNode* operator()(PDNode* x);
// declare operator node's name
// PATTERN_DECL_NODE(dropout);
// PATTERN_DECL_NODE(dropout_out);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
// PATTERN_DECL_NODE(dropout_qk);
// PATTERN_DECL_NODE(dropout_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
// The MulGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op.
class MultiHeadMatmulFusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(MultiHeadMatmulFusePass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) mul -> mul_out0
// (layer_norm_out, weights_1) mul -> mul_out1
// (layer_norm_out, weights_2) mul -> mul_out2
// (mul_out0, bias_0) elementweise_add -> eltadd_0
// (mul_out1, bias_1) elementweise_add -> eltadd_1
// (mul_out2, bias_2) elementweise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0) scale -> scale_0
// (scale_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) mul -> mul_qkv
Layers layers;
auto* x = layers.data("x", {128, 768});
auto out = layers.layer_norm(x);
auto* layer_out = out[0];
auto* weights_0 = layers.data("weights0", {768, 768}, true);
auto* weights_1 = layers.data("weights1", {768, 768}, true);
auto* weights_2 = layers.data("weights2", {768, 768}, true);
auto* mul_out_0 = layers.mul(layer_out, weights_0);
auto* mul_out_1 = layers.mul(layer_out, weights_1);
auto* mul_out_2 = layers.mul(layer_out, weights_2);
auto* b0 = layers.data("bias_0", {768}, true);
auto* b1 = layers.data("bias_1", {768}, true);
auto* b2 = layers.data("bias_2", {768}, true);
auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0);
auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1);
auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2);
std::vector<int> shape = {128, 12, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis);
auto* transpose_1 = layers.transpose2(reshape_1, axis);
auto* transpose_2 = layers.transpose2(reshape_2, axis);
auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false);
auto* matmul_qk = layers.matmul(scale_0, transpose_1);
auto* bqk = layers.data("biasqk", {768}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3});
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {128, 768});
auto* weights_l = layers.data("weightsl", {768, 768}, true);
layers.mul(reshape_qkv_out, weights_l);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass");
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_fused_nodes_after = GetNumOpNodes(graph, "multihead_matmul");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 29);
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(multihead_matmul_fuse_pass);
......@@ -175,6 +175,58 @@ struct Layers {
return outs;
}
VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("matmul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* transpose2(VarDesc* x, std::vector<int> axis) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("transpose2");
op->SetInput("X", {x->Name()});
op->SetAttr("axis", axis);
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* reshape2(VarDesc* x, std::vector<int> shape) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("reshape2");
op->SetInput("X", {x->Name()});
op->SetAttr("shape", shape);
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* softmax(VarDesc* x, int axis) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("softmax");
op->SetInput("X", {x->Name()});
op->SetAttr("axis", axis);
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* scale(VarDesc* x, float scale, float bias, bool bias_after) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("scale");
op->SetInput("X", {x->Name()});
op->SetAttr("scale", scale);
op->SetAttr("bias", bias);
op->SetAttr("bias_after_scale", bias_after);
op->SetOutput("Out", {out->Name()});
return out;
}
std::vector<VarDesc*> batch_norm(VarDesc* x, VarDesc* scale, VarDesc* bias,
VarDesc* mean, VarDesc* variance) {
VarDesc* y = lod_tensor(unique_name());
......
......@@ -106,12 +106,13 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
// "identity_scale_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"multihead_matmul_fuse_pass",
"fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
"conv_elementwise_add_act_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册