未验证 提交 14b7e3cf 编写于 作者: P Pei Yang 提交者: GitHub

[Paddle-TRT] TRT inference support for BERT/Transformer in paddle 2.0 api (#31744)

* support multihead_matmul_fuse_pass_v3

* fix compile problems

* embedding_eltwise_ln pass support lookup_table_v2

* suppoort matmul and matmul_v2 in qkv matmul
上级 245252b8
......@@ -34,15 +34,19 @@ namespace patterns {
static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
const std::string& arg,
bool is_persist = false) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
PDNode* node =
pattern->NewNode(name)->assert_is_op_input("lookup_table", arg);
pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
if (is_persist) return node->assert_is_persistable_var();
return node;
}
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
const std::string& arg) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
PDNode* node = pattern->NewNode(name)
->assert_is_only_output_of_op("lookup_table")
->assert_is_only_output_of_ops(embedding_ops)
->assert_is_op_input("elementwise_add", arg)
->AsIntermediate();
return node;
......@@ -56,10 +60,12 @@ void Embedding2Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
auto* lookup_table2_w =
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table2 =
pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table");
pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
auto* lookup_table1_out =
create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
auto* lookup_table2_out =
......@@ -80,8 +86,10 @@ void Embedding1Eltwise1Pattern::operator()() {
create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
auto* lookup_table1_w =
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table1_out =
create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
auto* eltwise_add =
......@@ -347,4 +355,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("lookup_table", 0)
.LE("lookup_table_v2", 1)
.EQ("elementweise_add", 0));
......@@ -652,6 +652,36 @@ PDNode *PDNode::assert_is_ops_input(
return this;
}
PDNode *PDNode::assert_is_only_input_of_ops(
const std::unordered_set<std::string> &op_types) {
assert_is_var();
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->outputs) {
if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
op->inputs.size() == 1) {
return true;
}
}
return false;
});
return this;
}
PDNode *PDNode::assert_is_only_output_of_ops(
const std::unordered_set<std::string> &op_types) {
assert_is_var();
asserts_.emplace_back([=](Node *x) {
for (auto *op : x->inputs) {
if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
op->outputs.size() == 1) {
return true;
}
}
return false;
});
return this;
}
bool VarLinksToOp(Node *node, const std::string &op_type) {
for (auto *out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
......
......@@ -145,6 +145,11 @@ struct PDNode {
const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth);
PDNode* assert_is_only_input_of_ops(
const std::unordered_set<std::string>& op_types);
PDNode* assert_is_only_output_of_ops(
const std::unordered_set<std::string>& op_types);
PDNode* assert_has_n_inputs(size_t n);
PDNode* assert_has_n_outputs(size_t n);
......
......@@ -682,6 +682,447 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
return fusion_count;
}
PDNode* MultiHeadMatmulV3Pattern::operator()() {
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("matmul");
// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
auto* mul0_w_var = pattern->NewNode(mul0_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul0_out_var =
pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
decltype(mul0) eltadd0;
decltype(mul0) eltadd0_b_var;
decltype(mul0) eltadd0_out_var;
mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())
->assert_is_op_output("elementwise_add");
eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_0 =
pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
auto* reshape2_0_out_var =
pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_0 =
pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2");
transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul");
auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
auto* matmul_qk_out_var =
pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
auto* eltadd_qk =
pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())
->assert_is_op_output("elementwise_add");
eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
auto* softmax_qk =
pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
auto* softmax_qk_out_var =
pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
auto* matmul_qkv =
pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
auto* matmul_qkv_out_var =
pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_qkv =
pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())
->assert_is_op_output("transpose2");
transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_qkv =
pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_op_input("matmul");
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul");
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul1_out_var =
pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul");
decltype(mul1) eltadd1;
decltype(mul1) eltadd1_b_var;
decltype(mul1) eltadd1_out_var;
mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())
->assert_is_op_output("elementwise_add");
eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_1 =
pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
auto* reshape2_1_out_var =
pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2");
transpose2_1_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qk
// Third path to matmul
auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
auto* mul2_w_var = pattern->NewNode(mul2_w_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto* mul2_out_var =
pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul");
decltype(mul2) eltadd2;
decltype(mul2) eltadd2_b_var;
decltype(mul2) eltadd2_out_var;
mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())
->assert_is_op_output("elementwise_add");
eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
auto* reshape2_2 =
pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
auto* reshape2_2_out_var =
pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2");
transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(
matmul_ops); // link to matmul qkv
// Q path
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
// K path
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
// compute q*k
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
.LinksTo({matmul_qk_out_var});
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
// compute q*k*v
matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var});
reshape2_qkv->LinksFrom({transpose2_qkv_out_var})
.LinksTo({reshape2_qkv_out_var});
return transpose2_2_out_var;
}
static int BuildFusionV3(Graph* graph, const std::string& name_scope,
Scope* scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](
Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
Node* reshape2, Node* reshape2_qkv_out, Node* matmul_qk) {
auto scale_attr = BOOST_GET_CONST(float, matmul_qk->Op()->GetAttr("alpha"));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
// bias (B * S * 3 * N * H) + bias (3 * N * H)
// Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H)
auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable<LoDTensor>();
auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable<LoDTensor>();
auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable<LoDTensor>();
auto* bq_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
auto* bk_tensor =
scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();
auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());
auto combined_w_dims =
framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* combined_w_desc = mul0_w->Var();
combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
combined_w_desc->SetPersistable(true);
auto* combined_bias_desc = eltadd0_b->Var();
combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc->SetPersistable(true);
framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
tmp_combined_w_data[out_index] = w_vec[j][in_index];
}
}
}
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data, tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());
scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel();
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data,
sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
sizeof(float) * bias_size);
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data, tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
auto reshape_desc = reshape2->Op();
int head_number =
BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")).at(2);
OpDesc multihead_op_desc;
multihead_op_desc.SetType("multihead_matmul");
multihead_op_desc.SetInput("Input", {input0->Name()});
multihead_op_desc.SetInput("W", {mul0_w->Name()});
multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
multihead_op_desc.SetAttr("alpha", scale_attr);
multihead_op_desc.SetAttr("head_number", head_number);
auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(input0, multihead);
IR_NODE_LINK_TO(mul0_w, multihead);
IR_NODE_LINK_TO(eltadd0_b, multihead);
IR_NODE_LINK_TO(eltadd_qk_b, multihead);
IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
// If weights or biases in qkv's fc are shared by multiple multihead_matmul
// patterns, we do not support this kind of fusion, this pass will not take
// effect.
bool is_fc_params_shared =
mul0_w->outputs.size() > 1 || mul1_w->outputs.size() > 1 ||
mul2_w->outputs.size() > 1 || eltadd0_b->outputs.size() > 1 ||
eltadd1_b->outputs.size() > 1 || eltadd2_b->outputs.size() > 1;
if (is_fc_params_shared) {
return;
}
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, matmul_qk);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
eltadd2,
eltadd1_b,
eltadd2_b,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul1_w,
mul2_w,
reshape2_qkv});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
} // namespace patterns
void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
......@@ -706,6 +1147,21 @@ void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
AddStatis(fusion_count);
}
void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multiheadMatmul pass, The scope should not be null."));
int fusion_count = patterns::BuildFusionV3(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kMultiheadMatmulPass, new bool(true));
}
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -715,6 +1171,8 @@ REGISTER_PASS(multihead_matmul_fuse_pass,
REGISTER_PASS(multihead_matmul_fuse_pass_v2,
paddle::framework::ir::MultiHeadMatmulV2FusePass);
REGISTER_PASS(multihead_matmul_fuse_pass_v3,
paddle::framework::ir::MultiHeadMatmulV3FusePass);
REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
......@@ -725,3 +1183,13 @@ REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v3)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.LE("matmul", 1)
.EQ("matmul_v2", 0)
.EQ("softmax", 0));
......@@ -89,9 +89,63 @@ struct MultiHeadMatmulPattern : public PatternBase {
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
struct MultiHeadMatmulV3Pattern : public PatternBase {
MultiHeadMatmulV3Pattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul_v3") {}
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
// The MulGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op.
class MultiHeadMatmulFusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulFusePass() {}
......@@ -112,6 +166,16 @@ class MultiHeadMatmulV2FusePass : public FusePassBase {
const std::string name_scope_{"multihead_matmul_fuse_v2"};
};
class MultiHeadMatmulV3FusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulV3FusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse_v3"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -86,6 +86,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"multihead_matmul_fuse_pass_v3", //
"skip_layernorm_fuse_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
......@@ -235,8 +236,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
//"fc_mkldnn_pass",
//"fc_act_mkldnn_fuse_pass",
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass",
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册