Unverified commit 967f4c2e, authored by Pei Yang, committed by GitHub

Cherry pick bert transformer 2.0 support (#31959)

* [Paddle-TRT] TRT inference support for BERT/Transformer in paddle 2.0 api (#31744)

* support multihead_matmul_fuse_pass_v3

* fix compile problems

* embedding_eltwise_ln pass support lookup_table_v2

* support matmul and matmul_v2 in qkv matmul

* support 3-dim inputs in map_matmul_to_mul_pass
Parent commit: b655bee2
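
For context, these fuse passes run during inference graph optimization once TensorRT is enabled. A minimal sketch of exercising them through the Paddle inference C++ API (the model path is a placeholder and the config values are illustrative; only the pass names come from this commit):

    #include "paddle_inference_api.h"

    int main() {
      paddle::AnalysisConfig config;
      config.SetModel("./bert_model");  // placeholder model directory
      // Enabling TensorRT selects kTRTSubgraphPasses, which after this commit
      // include multihead_matmul_fuse_pass_v3 and the extended
      // embedding_eltwise_layernorm_fuse_pass.
      config.EnableTensorRtEngine(1 << 30 /* workspace_size */,
                                  1 /* max_batch_size */,
                                  3 /* min_subgraph_size */);
      auto predictor = paddle::CreatePaddlePredictor(config);
      return predictor != nullptr ? 0 : 1;
    }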
......@@ -29,15 +29,19 @@ namespace patterns {
 static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
                                const std::string& arg,
                                bool is_persist = false) {
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   PDNode* node =
-      pattern->NewNode(name)->assert_is_op_input("lookup_table", arg);
+      pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
   if (is_persist) return node->assert_is_persistable_var();
   return node;
 }
 static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
                                    const std::string& arg) {
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   PDNode* node = pattern->NewNode(name)
-                     ->assert_is_only_output_of_op("lookup_table")
+                     ->assert_is_only_output_of_ops(embedding_ops)
                      ->assert_is_op_input("elementwise_add", arg)
                      ->AsIntermediate();
   return node;
......@@ -51,10 +55,12 @@ void Embedding2Eltwise1Pattern::operator()() {
       create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
   auto* lookup_table2_w =
       create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   auto* lookup_table1 =
-      pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table2 =
-      pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table1_out =
       create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
   auto* lookup_table2_out =
......@@ -75,8 +81,10 @@ void Embedding1Eltwise1Pattern::operator()() {
       create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
   auto* lookup_table1_w =
       create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
+  std::unordered_set<std::string> embedding_ops{"lookup_table",
+                                                "lookup_table_v2"};
   auto* lookup_table1 =
-      pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
+      pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
   auto* lookup_table1_out =
       create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
   auto* eltwise_add =
......@@ -342,4 +350,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .EQ("lookup_table", 0)
+            .LE("lookup_table_v2", 1)
             .EQ("elementwise_add", 0));
......@@ -662,6 +662,36 @@ PDNode *PDNode::assert_is_ops_input(
   return this;
 }
 
+PDNode *PDNode::assert_is_only_input_of_ops(
+    const std::unordered_set<std::string> &op_types) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node *x) {
+    for (auto *op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
+          op->inputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+
+PDNode *PDNode::assert_is_only_output_of_ops(
+    const std::unordered_set<std::string> &op_types) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node *x) {
+    for (auto *op : x->inputs) {
+      if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
+          op->outputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+
 bool VarLinksToOp(Node *node, const std::string &op_type) {
   for (auto *out : node->outputs) {
     if (out->IsOp() && out->Op()->Type() == op_type) {
......
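
The two new asserts generalize the existing single-op assert_is_only_input_of_op / assert_is_only_output_of_op helpers to a set of op types: the var must be the sole input (respectively sole output) of at least one adjacent op whose type is in the set. A minimal usage sketch, assuming a PDPattern* pattern is in scope:

    // Match a var that is the only output of a lookup_table or
    // lookup_table_v2 op and that feeds elementwise_add as input "X".
    PDNode* emb_out =
        pattern->NewNode("emb_out")
            ->assert_is_only_output_of_ops({"lookup_table", "lookup_table_v2"})
            ->assert_is_op_input("elementwise_add", "X");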
......@@ -146,6 +146,11 @@ struct PDNode {
       const std::unordered_set<std::string>& op_types,
       const std::string& argument, int nth);
 
+  PDNode* assert_is_only_input_of_ops(
+      const std::unordered_set<std::string>& op_types);
+  PDNode* assert_is_only_output_of_ops(
+      const std::unordered_set<std::string>& op_types);
+
   PDNode* assert_has_n_inputs(size_t n);
   PDNode* assert_has_n_outputs(size_t n);
......
......@@ -57,7 +57,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
       std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
       size_t x_rank = x_shape.size();
       size_t y_rank = y_shape.size();
-      flag = flag && x_rank == 2 && y_rank == 2;
+      flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2;
 
       std::vector<Node*>& next_ops = matmul_out->outputs;
       flag = flag && next_ops.size() == 1 &&
......@@ -69,7 +69,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
       desc.SetInput("X", {matmul_in_x->Name()});
       desc.SetInput("Y", {matmul_in_y->Name()});
       desc.SetOutput("Out", {matmul_out->Name()});
-      desc.SetAttr("x_num_col_dims", 1);
+      desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
       desc.SetAttr("y_num_col_dims", 1);
       if (matmul_op->Op()->HasAttr("enable_int8")) {
         desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
......
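
Why x_num_col_dims becomes x_rank - 1: the mul op flattens X to a 2-D matrix, with the first x_num_col_dims dimensions forming the rows and the rest forming the columns. A worked example for the newly allowed 3-dim input (shapes illustrative):

    // x_shape = {B, S, H} (x_rank == 3), y_shape = {H, N}
    // x_num_col_dims = 3 - 1 = 2
    //   => X flattens to [B * S, H], is multiplied by Y [H, N],
    //      and the output is viewed as [B, S, N].
    // For x_rank == 2 this reduces to the old behavior, x_num_col_dims == 1.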
......@@ -89,9 +89,63 @@ struct MultiHeadMatmulPattern : public PatternBase {
   PATTERN_DECL_NODE(matmul_qkv);
   PATTERN_DECL_NODE(matmul_qkv_out);
 };
 
+struct MultiHeadMatmulV3Pattern : public PatternBase {
+  MultiHeadMatmulV3Pattern(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "multihead_matmul_v3") {}
+
+  PDNode* operator()();
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(input0);
+  PATTERN_DECL_NODE(mul0);
+  PATTERN_DECL_NODE(mul1);
+  PATTERN_DECL_NODE(mul2);
+  PATTERN_DECL_NODE(mul0_w);
+  PATTERN_DECL_NODE(mul1_w);
+  PATTERN_DECL_NODE(mul2_w);
+  PATTERN_DECL_NODE(mul0_out);
+  PATTERN_DECL_NODE(mul1_out);
+  PATTERN_DECL_NODE(mul2_out);
+  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
+  PATTERN_DECL_NODE(eltadd0_out);
+  PATTERN_DECL_NODE(eltadd1_out);
+  PATTERN_DECL_NODE(eltadd2_out);
+  PATTERN_DECL_NODE(reshape2_0);
+  PATTERN_DECL_NODE(reshape2_1);
+  PATTERN_DECL_NODE(reshape2_2);
+  PATTERN_DECL_NODE(reshape2_qkv);
+  PATTERN_DECL_NODE(reshape2_0_out);
+  PATTERN_DECL_NODE(reshape2_1_out);
+  PATTERN_DECL_NODE(reshape2_2_out);
+  PATTERN_DECL_NODE(reshape2_qkv_out);
+  PATTERN_DECL_NODE(transpose2_0);
+  PATTERN_DECL_NODE(transpose2_1);
+  PATTERN_DECL_NODE(transpose2_2);
+  PATTERN_DECL_NODE(transpose2_qkv);
+  PATTERN_DECL_NODE(transpose2_0_out);
+  PATTERN_DECL_NODE(transpose2_1_out);
+  PATTERN_DECL_NODE(transpose2_2_out);
+  PATTERN_DECL_NODE(transpose2_qkv_out);
+  PATTERN_DECL_NODE(matmul_qk);
+  PATTERN_DECL_NODE(matmul_qk_out);
+  PATTERN_DECL_NODE(eltadd_qk);
+  PATTERN_DECL_NODE(eltadd_qk_b);
+  PATTERN_DECL_NODE(eltadd_qk_out);
+  PATTERN_DECL_NODE(softmax_qk);
+  PATTERN_DECL_NODE(softmax_qk_out);
+  PATTERN_DECL_NODE(matmul_qkv);
+  PATTERN_DECL_NODE(matmul_qkv_out);
+};
+
 }  // namespace patterns
 // The passes below fuse the matched multi-head attention subgraph into a
 // single multihead_matmul op.
 class MultiHeadMatmulFusePass : public FusePassBase {
  public:
   virtual ~MultiHeadMatmulFusePass() {}
......@@ -112,6 +166,16 @@ class MultiHeadMatmulV2FusePass : public FusePassBase {
   const std::string name_scope_{"multihead_matmul_fuse_v2"};
 };
 
+class MultiHeadMatmulV3FusePass : public FusePassBase {
+ public:
+  virtual ~MultiHeadMatmulV3FusePass() {}
+
+ protected:
+  void ApplyImpl(Graph* graph) const;
+
+  const std::string name_scope_{"multihead_matmul_fuse_v3"};
+};
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
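
Per the commit message, the V3 pattern differs from V2 in accepting matmul and matmul_v2 (what the 2.0 API emits) for the Q/K/V projections, where V2 anchors on mul ops. The ApplyImpl body lives in the corresponding .cc file, which this page collapses; a minimal sketch of the pattern-detector idiom such passes follow, with the rewrite handler elided:

    GraphPatternDetector gpd;
    patterns::MultiHeadMatmulV3Pattern pattern(gpd.mutable_pattern(),
                                               "multihead_matmul_fuse_v3");
    pattern();  // wires up the PDNodes declared above
    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
      // Retrieve matched nodes (e.g. via GET_IR_NODE_FROM_SUBGRAPH) and
      // replace the subgraph with a single multihead_matmul op.
    };
    gpd(graph, handler);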
......@@ -82,6 +82,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "simplify_with_basic_ops_pass",           //
       "embedding_eltwise_layernorm_fuse_pass",  //
       "multihead_matmul_fuse_pass_v2",          //
+      "multihead_matmul_fuse_pass_v3",          //
       "skip_layernorm_fuse_pass",               //
       "unsqueeze2_eltwise_fuse_pass",           //
       "conv_bn_fuse_pass",                      //
......
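
kTRTSubgraphPasses is the ordered pass list applied before TensorRT subgraph conversion, so v3 runs right after v2 and only matches what v2 left unfused. For the builder to resolve the new name, the pass has to be registered; the registration lives in the pass's .cc file (collapsed on this page) and would follow the standard macro, along the lines of:

    REGISTER_PASS(multihead_matmul_fuse_pass_v3,
                  paddle::framework::ir::MultiHeadMatmulV3FusePass);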