Unverified commit 14b7e3cf authored by Pei Yang, committed by GitHub

[Paddle-TRT] TRT inference support for BERT/Transformer in paddle 2.0 api (#31744)

* support multihead_matmul_fuse_pass_v3

* fix compile problems

* embedding_eltwise_ln pass support lookup_table_v2

* support matmul and matmul_v2 in qkv matmul
Parent 245252b8
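For context, the sketch below is not part of this commit: it shows roughly how the new fusion work gets exercised from the Paddle 2.0 C++ inference API. Once TensorRT is enabled on the config, the analysis phase runs the kTRTSubgraphPasses list changed near the bottom of this diff. Model paths and sizes here are placeholders.

#include <memory>

#include "paddle_inference_api.h"  // shipped with the Paddle inference library

int main() {
  paddle_infer::Config config;
  // Placeholder paths: point these at an exported BERT/Transformer
  // inference model.
  config.SetModel("bert.pdmodel", "bert.pdiparams");
  config.EnableUseGpu(100 /*initial pool, MB*/, 0 /*device id*/);
  // With TensorRT on, graph analysis runs kTRTSubgraphPasses, which after
  // this commit includes multihead_matmul_fuse_pass_v3.
  config.EnableTensorRtEngine(1 << 30 /*workspace bytes*/, 1 /*max batch*/,
                              3 /*min subgraph size*/,
                              paddle_infer::PrecisionType::kFloat32,
                              false /*use_static*/, false /*use_calib_mode*/);
  std::shared_ptr<paddle_infer::Predictor> predictor =
      paddle_infer::CreatePredictor(config);
  return predictor ? 0 : 1;
}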
@@ -34,15 +34,19 @@ namespace patterns {
 static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
                                const std::string& arg,
                                bool is_persist = false) {
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   PDNode* node =
-      pattern->NewNode(name)->assert_is_op_input("lookup_table", arg);
+      pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
   if (is_persist) return node->assert_is_persistable_var();
   return node;
 }
 static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
                                    const std::string& arg) {
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   PDNode* node = pattern->NewNode(name)
-                     ->assert_is_only_output_of_op("lookup_table")
+                     ->assert_is_only_output_of_ops(embedding_ops)
                      ->assert_is_op_input("elementwise_add", arg)
                      ->AsIntermediate();
   return node;
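To make the generalized matching concrete, here is an illustrative sketch (not from the commit) of a pattern fragment built with these helpers. It assumes a PDPattern* named pattern; the node names "ids", "emb", and "emb_out" are invented for this example.

// Sketch: one embedding subgraph that matches lookup_table OR lookup_table_v2.
std::unordered_set<std::string> embedding_ops{"lookup_table",
                                              "lookup_table_v2"};
PDNode* ids = pattern->NewNode("ids")
                  ->assert_is_ops_input(embedding_ops, "Ids");
PDNode* emb = pattern->NewNode("emb")->assert_is_ops(embedding_ops);
PDNode* emb_out = pattern->NewNode("emb_out")
                      ->assert_is_only_output_of_ops(embedding_ops)
                      ->assert_is_op_input("elementwise_add", "X");
emb->LinksFrom({ids}).LinksTo({emb_out});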
@@ -56,10 +60,12 @@ void Embedding2Eltwise1Pattern::operator()() {
       create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
   auto* lookup_table2_w =
       create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   auto* lookup_table1 =
-      pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table2 =
-      pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table1_out =
       create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
   auto* lookup_table2_out =
@@ -80,8 +86,10 @@ void Embedding1Eltwise1Pattern::operator()() {
       create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
   auto* lookup_table1_w =
       create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   auto* lookup_table1 =
-      pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table1_out =
       create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
   auto* eltwise_add =
@@ -347,4 +355,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .EQ("lookup_table", 0)
+            .LE("lookup_table_v2", 1)
             .EQ("elementwise_add", 0));
@@ -652,6 +652,36 @@ PDNode *PDNode::assert_is_ops_input(
   return this;
 }
+
+PDNode *PDNode::assert_is_only_input_of_ops(
+    const std::unordered_set<std::string> &op_types) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node *x) {
+    for (auto *op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
+          op->inputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+
+PDNode *PDNode::assert_is_only_output_of_ops(
+    const std::unordered_set<std::string> &op_types) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node *x) {
+    for (auto *op : x->inputs) {
+      if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
+          op->outputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+
 bool VarLinksToOp(Node *node, const std::string &op_type) {
   for (auto *out : node->outputs) {
     if (out->IsOp() && out->Op()->Type() == op_type) {
......
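Semantics note: each new assert passes if at least one adjacent op of a listed type uses the variable as that op's only input (respectively, produces it as the op's only output). For illustration only, assert_is_only_output_of_ops is roughly this hand-rolled check through the public assert_more() hook; node is assumed to be a PDNode*.

// Rough equivalent of assert_is_only_output_of_ops({...}), for illustration.
std::unordered_set<std::string> ops{"lookup_table", "lookup_table_v2"};
node->assert_is_var()->assert_more([ops](Node* x) {
  for (Node* producer : x->inputs) {  // ops writing this variable
    if (producer->IsOp() && producer->Op() &&
        ops.count(producer->Op()->Type()) &&
        producer->outputs.size() == 1) {
      return true;  // x is the sole output of a matching producer
    }
  }
  return false;
});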
@@ -145,6 +145,11 @@ struct PDNode {
       const std::unordered_set<std::string>& op_types,
       const std::string& argument, int nth);
+  PDNode* assert_is_only_input_of_ops(
+      const std::unordered_set<std::string>& op_types);
+  PDNode* assert_is_only_output_of_ops(
+      const std::unordered_set<std::string>& op_types);
+
   PDNode* assert_has_n_inputs(size_t n);
   PDNode* assert_has_n_outputs(size_t n);
......
@@ -89,9 +89,63 @@ struct MultiHeadMatmulPattern : public PatternBase {
   PATTERN_DECL_NODE(matmul_qkv);
   PATTERN_DECL_NODE(matmul_qkv_out);
 };
+
+struct MultiHeadMatmulV3Pattern : public PatternBase {
+  MultiHeadMatmulV3Pattern(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "multihead_matmul_v3") {}
+
+  PDNode* operator()();
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(input0);
+  PATTERN_DECL_NODE(mul0);
+  PATTERN_DECL_NODE(mul1);
+  PATTERN_DECL_NODE(mul2);
+  PATTERN_DECL_NODE(mul0_w);
+  PATTERN_DECL_NODE(mul1_w);
+  PATTERN_DECL_NODE(mul2_w);
+  PATTERN_DECL_NODE(mul0_out);
+  PATTERN_DECL_NODE(mul1_out);
+  PATTERN_DECL_NODE(mul2_out);
+  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_out);
+  PATTERN_DECL_NODE(eltadd1_out);
+  PATTERN_DECL_NODE(eltadd2_out);
+  PATTERN_DECL_NODE(reshape2_0);
+  PATTERN_DECL_NODE(reshape2_1);
+  PATTERN_DECL_NODE(reshape2_2);
+  PATTERN_DECL_NODE(reshape2_qkv);
+  PATTERN_DECL_NODE(reshape2_0_out);
+  PATTERN_DECL_NODE(reshape2_1_out);
+  PATTERN_DECL_NODE(reshape2_2_out);
+  PATTERN_DECL_NODE(reshape2_qkv_out);
+  PATTERN_DECL_NODE(transpose2_0);
+  PATTERN_DECL_NODE(transpose2_1);
+  PATTERN_DECL_NODE(transpose2_2);
+  PATTERN_DECL_NODE(transpose2_qkv);
+  PATTERN_DECL_NODE(transpose2_0_out);
+  PATTERN_DECL_NODE(transpose2_1_out);
+  PATTERN_DECL_NODE(transpose2_2_out);
+  PATTERN_DECL_NODE(transpose2_qkv_out);
+  PATTERN_DECL_NODE(matmul_qk);
+  PATTERN_DECL_NODE(matmul_qk_out);
+  PATTERN_DECL_NODE(eltadd_qk);
+  PATTERN_DECL_NODE(eltadd_qk_b);
+  PATTERN_DECL_NODE(eltadd_qk_out);
+  PATTERN_DECL_NODE(softmax_qk);
+  PATTERN_DECL_NODE(softmax_qk_out);
+  PATTERN_DECL_NODE(matmul_qkv);
+  PATTERN_DECL_NODE(matmul_qkv_out);
+};
+
 }  // namespace patterns
+
 // The MulGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op.
 class MultiHeadMatmulFusePass : public FusePassBase {
  public:
   virtual ~MultiHeadMatmulFusePass() {}
@@ -112,6 +166,16 @@ class MultiHeadMatmulV2FusePass : public FusePassBase {
   const std::string name_scope_{"multihead_matmul_fuse_v2"};
 };
+
+class MultiHeadMatmulV3FusePass : public FusePassBase {
+ public:
+  virtual ~MultiHeadMatmulV3FusePass() {}
+
+ protected:
+  void ApplyImpl(Graph* graph) const;
+
+  const std::string name_scope_{"multihead_matmul_fuse_v3"};
+};
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
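A hedged sketch (not from the commit) of driving the new pass directly over an ir::Graph, the way fuse-pass unit tests commonly do. It assumes graph is a std::unique_ptr<paddle::framework::ir::Graph> already built from a ProgramDesc.

// Sketch: fetch the registered pass by name and apply it to the graph.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "multihead_matmul_fuse_pass_v3");
graph.reset(pass->Apply(graph.release()));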
@@ -86,6 +86,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "simplify_with_basic_ops_pass",           //
       "embedding_eltwise_layernorm_fuse_pass",  //
       "multihead_matmul_fuse_pass_v2",          //
+      "multihead_matmul_fuse_pass_v3",          //
       "skip_layernorm_fuse_pass",               //
       "conv_bn_fuse_pass",                      //
       "unsqueeze2_eltwise_fuse_pass",           //
@@ -235,8 +236,8 @@ void CpuPassStrategy::EnableMKLDNN() {
       "reshape_transpose_matmul_mkldnn_fuse_pass",  //
       "matmul_transpose_reshape_fuse_pass",         //
       // Disabled due to topology-dependent speed-up
-      //"fc_mkldnn_pass",
-      //"fc_act_mkldnn_fuse_pass",
+      // "fc_mkldnn_pass",
+      // "fc_act_mkldnn_fuse_pass",
       "batch_norm_act_fuse_pass",
       // TODO(intel): Please fix the bug on windows.
       // https://github.com/PaddlePaddle/Paddle/issues/29710
......