Unverified commit 14b7e3cf authored by Pei Yang, committed by GitHub

[Paddle-TRT] TRT inference support for BERT/Transformer in paddle 2.0 api (#31744)

* support multihead_matmul_fuse_pass_v3

* fix compile problems

* support lookup_table_v2 in embedding_eltwise_ln pass

* support matmul and matmul_v2 in qkv matmul
Parent 245252b8
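The passes touched below feed the TensorRT subgraph pipeline. As a rough illustration of how they get exercised end to end, here is a minimal sketch of enabling Paddle-TRT through the Paddle Inference 2.0 C++ API (the model file names are placeholders, not taken from this commit):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  // Placeholder paths for a BERT-style saved model.
  config.SetModel("model.pdmodel", "model.pdiparams");
  config.EnableUseGpu(100 /* initial GPU memory (MB) */, 0 /* device id */);
  // Turns on the TRT subgraph engine; the kTRTSubgraphPasses list below
  // (including the new multihead_matmul_fuse_pass_v3) runs in this mode.
  config.EnableTensorRtEngine(1 << 30 /* workspace bytes */, 1 /* max batch */,
                              5 /* min subgraph size */,
                              paddle_infer::PrecisionType::kFloat32,
                              false /* use_static */, false /* use_calib */);
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}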
@@ -34,15 +34,19 @@ namespace patterns {
 static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
                                const std::string& arg,
                                bool is_persist = false) {
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   PDNode* node =
-      pattern->NewNode(name)->assert_is_op_input("lookup_table", arg);
+      pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
   if (is_persist) return node->assert_is_persistable_var();
   return node;
 }
 static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
                                    const std::string& arg) {
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   PDNode* node = pattern->NewNode(name)
-                     ->assert_is_only_output_of_op("lookup_table")
+                     ->assert_is_only_output_of_ops(embedding_ops)
                      ->assert_is_op_input("elementwise_add", arg)
                      ->AsIntermediate();
   return node;
@@ -56,10 +60,12 @@ void Embedding2Eltwise1Pattern::operator()() {
       create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
   auto* lookup_table2_w =
       create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   auto* lookup_table1 =
-      pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table2 =
-      pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table1_out =
       create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
   auto* lookup_table2_out =
@@ -80,8 +86,10 @@ void Embedding1Eltwise1Pattern::operator()() {
       create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
   auto* lookup_table1_w =
       create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   auto* lookup_table1 =
-      pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table1_out =
       create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
   auto* eltwise_add =
@@ -347,4 +355,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
         .AddCombination(
             paddle::framework::compatible::OpVersionComparatorCombination()
                 .EQ("lookup_table", 0)
+                .LE("lookup_table_v2", 1)
                 .EQ("elementweise_add", 0));

@@ -652,6 +652,36 @@ PDNode *PDNode::assert_is_ops_input(
   return this;
 }
+
+PDNode *PDNode::assert_is_only_input_of_ops(
+    const std::unordered_set<std::string> &op_types) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node *x) {
+    for (auto *op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
+          op->inputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+
+PDNode *PDNode::assert_is_only_output_of_ops(
+    const std::unordered_set<std::string> &op_types) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node *x) {
+    for (auto *op : x->inputs) {
+      if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
+          op->outputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+
 bool VarLinksToOp(Node *node, const std::string &op_type) {
   for (auto *out : node->outputs) {
     if (out->IsOp() && out->Op()->Type() == op_type) {
...
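For context (not part of the diff), the two asserts above are meant to be chained when declaring pattern nodes, in the same style as create_emb_out_vars in the first file; a small sketch with an illustrative node name:

  // Matches a var that is the sole output of either embedding op and is
  // consumed as input "X" of an elementwise_add.
  std::unordered_set<std::string> embedding_ops{"lookup_table",
                                                "lookup_table_v2"};
  PDNode* emb_out = pattern->NewNode("emb_out")  // node name is illustrative
                        ->assert_is_only_output_of_ops(embedding_ops)
                        ->assert_is_op_input("elementwise_add", "X")
                        ->AsIntermediate();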

@@ -145,6 +145,11 @@ struct PDNode {
       const std::unordered_set<std::string>& op_types,
       const std::string& argument, int nth);
+  PDNode* assert_is_only_input_of_ops(
+      const std::unordered_set<std::string>& op_types);
+  PDNode* assert_is_only_output_of_ops(
+      const std::unordered_set<std::string>& op_types);
+
   PDNode* assert_has_n_inputs(size_t n);
   PDNode* assert_has_n_outputs(size_t n);
...

@@ -89,9 +89,63 @@ struct MultiHeadMatmulPattern : public PatternBase {
   PATTERN_DECL_NODE(matmul_qkv);
   PATTERN_DECL_NODE(matmul_qkv_out);
 };
+
+struct MultiHeadMatmulV3Pattern : public PatternBase {
+  MultiHeadMatmulV3Pattern(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "multihead_matmul_v3") {}
+
+  PDNode* operator()();
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(input0);
+  PATTERN_DECL_NODE(mul0);
+  PATTERN_DECL_NODE(mul1);
+  PATTERN_DECL_NODE(mul2);
+  PATTERN_DECL_NODE(mul0_w);
+  PATTERN_DECL_NODE(mul1_w);
+  PATTERN_DECL_NODE(mul2_w);
+  PATTERN_DECL_NODE(mul0_out);
+  PATTERN_DECL_NODE(mul1_out);
+  PATTERN_DECL_NODE(mul2_out);
+  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_out);
+  PATTERN_DECL_NODE(eltadd1_out);
+  PATTERN_DECL_NODE(eltadd2_out);
+  PATTERN_DECL_NODE(reshape2_0);
+  PATTERN_DECL_NODE(reshape2_1);
+  PATTERN_DECL_NODE(reshape2_2);
+  PATTERN_DECL_NODE(reshape2_qkv);
+  PATTERN_DECL_NODE(reshape2_0_out);
+  PATTERN_DECL_NODE(reshape2_1_out);
+  PATTERN_DECL_NODE(reshape2_2_out);
+  PATTERN_DECL_NODE(reshape2_qkv_out);
+  PATTERN_DECL_NODE(transpose2_0);
+  PATTERN_DECL_NODE(transpose2_1);
+  PATTERN_DECL_NODE(transpose2_2);
+  PATTERN_DECL_NODE(transpose2_qkv);
+  PATTERN_DECL_NODE(transpose2_0_out);
+  PATTERN_DECL_NODE(transpose2_1_out);
+  PATTERN_DECL_NODE(transpose2_2_out);
+  PATTERN_DECL_NODE(transpose2_qkv_out);
+  PATTERN_DECL_NODE(matmul_qk);
+  PATTERN_DECL_NODE(matmul_qk_out);
+  PATTERN_DECL_NODE(eltadd_qk);
+  PATTERN_DECL_NODE(eltadd_qk_b);
+  PATTERN_DECL_NODE(eltadd_qk_out);
+  PATTERN_DECL_NODE(softmax_qk);
+  PATTERN_DECL_NODE(softmax_qk_out);
+  PATTERN_DECL_NODE(matmul_qkv);
+  PATTERN_DECL_NODE(matmul_qkv_out);
+};
+
 }  // namespace patterns
+// The MulGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op.
 class MultiHeadMatmulFusePass : public FusePassBase {
  public:
   virtual ~MultiHeadMatmulFusePass() {}
@@ -112,6 +166,16 @@ class MultiHeadMatmulV2FusePass : public FusePassBase {
   const std::string name_scope_{"multihead_matmul_fuse_v2"};
 };
+
+class MultiHeadMatmulV3FusePass : public FusePassBase {
+ public:
+  virtual ~MultiHeadMatmulV3FusePass() {}
+
+ protected:
+  void ApplyImpl(Graph* graph) const;
+
+  const std::string name_scope_{"multihead_matmul_fuse_v3"};
+};
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
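As a reference for how the new pass would be invoked in isolation (this snippet is not part of the commit; it follows the idiom used in Paddle's pass unit tests):

  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "multihead_matmul_fuse_pass_v3");
  // graph is a std::unique_ptr<paddle::framework::ir::Graph> built elsewhere.
  graph.reset(pass->Apply(graph.release()));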

@@ -86,6 +86,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "simplify_with_basic_ops_pass",           //
       "embedding_eltwise_layernorm_fuse_pass",  //
       "multihead_matmul_fuse_pass_v2",          //
+      "multihead_matmul_fuse_pass_v3",          //
       "skip_layernorm_fuse_pass",               //
       "conv_bn_fuse_pass",                      //
       "unsqueeze2_eltwise_fuse_pass",           //
@@ -235,8 +236,8 @@ void CpuPassStrategy::EnableMKLDNN() {
           "reshape_transpose_matmul_mkldnn_fuse_pass",  //
           "matmul_transpose_reshape_fuse_pass",         //
           // Disabled due to topology-dependent speed-up
-          //"fc_mkldnn_pass",
-          //"fc_act_mkldnn_fuse_pass",
+          // "fc_mkldnn_pass",
+          // "fc_act_mkldnn_fuse_pass",
           "batch_norm_act_fuse_pass",
           // TODO(intel): Please fix the bug on windows.
           // https://github.com/PaddlePaddle/Paddle/issues/29710
...