Unverified commit 29eec2dd, authored by lzy, committed by GitHub

add multi_devices_fused_multi_transformer_encoder_pass and cherry-pick from 48349 (#49383)

Parent a2d7e1d7
......@@ -31,6 +31,8 @@ namespace framework {
namespace ir {
namespace patterns {
static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
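The only functional change in this first hunk is the new `FFN_ACTS` set: pattern asserts that previously hard-coded `"gelu"` now accept any op type in the set. A minimal standalone sketch of that membership check, assuming set-membership semantics for `assert_is_ops`; `kFfnActs` and `MatchesFfnAct` are illustrative names, not Paddle API:

```cpp
#include <iostream>
#include <string>
#include <unordered_set>

// Stand-in for the pattern's activation check: any op in the set can anchor
// the FFN subgraph, where the old pass matched only "gelu".
static const std::unordered_set<std::string> kFfnActs{"relu", "gelu"};

bool MatchesFfnAct(const std::string& op_type) {
  return kFfnActs.count(op_type) > 0;
}

int main() {
  for (const char* op : {"gelu", "relu", "sigmoid"}) {
    std::cout << op << ": " << (MatchesFfnAct(op) ? "matched" : "not matched")
              << "\n";
  }
  return 0;
}
```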
......@@ -359,11 +361,11 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_ops_output(FFN_ACTS)
->AsIntermediate()
->assert_is_op_input("matmul_v2");
......@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
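The links above encode the FFN dataflow `ffn_matmul0 -> ffn_eltadd0 -> ffn_act -> ffn_matmul1 -> ffn_eltadd1`. Below is a scalar sketch of that topology; `FfnReference` is a hypothetical reference function with scalar stand-ins for the real tensors, illustrating what the matched subgraph computes, not the fused GPU kernel:

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>

// Scalar walk through the matched FFN chain.
float FfnReference(float x, float w0, float b0, float w1, float b1,
                   bool use_relu) {
  float h = x * w0 + b0;                // ffn_matmul0 + ffn_eltadd0
  h = use_relu
          ? std::max(h, 0.0f)           // "relu" branch of FFN_ACTS
          : 0.5f * h * (1.0f + std::erf(h / std::sqrt(2.0f)));  // "gelu"
  return h * w1 + b1;                   // ffn_matmul1 + ffn_eltadd1
}

int main() {
  std::cout << FfnReference(1.0f, 0.5f, 0.1f, 2.0f, -0.2f, true) << "\n";
  return 0;
}
```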
......@@ -678,11 +680,11 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_ops_output(FFN_ACTS)
->AsIntermediate()
->assert_is_op_input("matmul_v2");
......@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
......@@ -1026,11 +1028,11 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_ops_output(FFN_ACTS)
->AsIntermediate()
->assert_is_op_input("matmul_v2");
......@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var});
......@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) {
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
......@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute
fused_multi_transformer_op_desc.SetAttr("is_test", true);
......@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
ffn_act, ffn_act, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
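`GET_IR_NODE_FROM_SUBGRAPH` binds a local variable to the graph node that the matcher paired with a pattern node. The macro below is a simplified stand-in, not the real Paddle macro, with strings in place of `Node*`:

```cpp
#include <iostream>
#include <map>
#include <string>

// Simplified shape of the binding: look the pattern-node key up in the
// matcher's pattern-node -> graph-node map and name the result.
#define GET_NODE_FROM_SUBGRAPH(var, key, subgraph) \
  const std::string& var = (subgraph).at(#key)

int main() {
  std::map<std::string, std::string> subgraph{{"ffn_act", "gelu@op_41"},
                                              {"ffn_act_out", "tmp_9"}};
  GET_NODE_FROM_SUBGRAPH(ffn_act, ffn_act, subgraph);
  GET_NODE_FROM_SUBGRAPH(ffn_act_out, ffn_act_out, subgraph);
  std::cout << ffn_act << " -> " << ffn_act_out << "\n";
  return 0;
}
```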
......@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_act,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
......@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_act,
ffn_act_out,
ffn_eltadd_out});
// Remove unneeded nodes.
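After rewiring inputs and outputs onto the fused op, every matched intermediate (now including `ffn_act` and `ffn_act_out`) is collected into `marked_nodes` and erased. A toy version over names instead of `Node*`:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  std::vector<std::string> graph{"ffn_eltadd0", "ffn_act", "ffn_act_out",
                                 "fused_multi_transformer"};
  const std::unordered_set<std::string> marked{"ffn_eltadd0", "ffn_act",
                                               "ffn_act_out"};
  // Erase-remove idiom: drop every node in the marked set.
  graph.erase(std::remove_if(graph.begin(), graph.end(),
                             [&](const std::string& n) {
                               return marked.count(n) > 0;
                             }),
              graph.end());
  for (const auto& n : graph) std::cout << n << "\n";  // fused op remains
  return 0;
}
```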
......@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) {
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
......@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
......@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
......@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_act,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
......@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_act,
ffn_act_out,
ffn_eltadd_out});
// Remove unneeded nodes.
......@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) {
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
......@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
......@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
......@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_act,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
......@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_act,
ffn_act_out,
ffn_eltadd_out});
// Remove unneeded nodes.
......
......@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......
......@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
......@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(scale_q);
PATTERN_DECL_NODE(scale_q_out);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
......@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// while loop
PATTERN_DECL_NODE(while0);
// post layer_norm
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
};
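The reorganized encoder pattern above also matches an explicit `scale_q` node: Q is scaled before the QK matmul rather than scaling the product. A scalar sketch of that attention-score math, assuming the usual 1/sqrt(head_dim) scaling (the actual scale comes from the matched op's attribute):

```cpp
#include <cmath>
#include <iostream>

int main() {
  const float q = 0.8f, k = 1.2f;
  const int head_dim = 64;
  // Assumed scale; real value is read from the scale_q op.
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  const float score = (q * scale) * k;  // scale_q -> matmul_qk
  std::cout << "score = " << score << "\n";
  return 0;
}
```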
struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
......@@ -212,8 +216,8 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -226,6 +230,129 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_output);
};
struct MultiDevicesFusedMultiTransformerEncoderPattern : public PatternBase {
MultiDevicesFusedMultiTransformerEncoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_encoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(c_identity0);
PATTERN_DECL_NODE(c_identity0_out);
PATTERN_DECL_NODE(c_identity1);
PATTERN_DECL_NODE(c_identity1_out);
PATTERN_DECL_NODE(c_identity2);
PATTERN_DECL_NODE(c_identity2_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(scale_q);
PATTERN_DECL_NODE(scale_q_out);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// post layer_norm
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
};
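The multi-devices pattern differs from the single-device one by the `c_identity` nodes before the Q/K/V and FFN matmuls and the `c_allreduce_sum` nodes after the out-linear and second FFN matmul: under tensor parallelism each rank holds a weight shard, so each matmul yields a partial result that must be summed across ranks. `AllReduceSum` below is an in-process stand-in for that reduction (the real op runs over collective communication):

```cpp
#include <iostream>
#include <numeric>
#include <vector>

// Sum per-rank partial results, as c_allreduce_sum does across devices.
float AllReduceSum(const std::vector<float>& per_rank_partials) {
  return std::accumulate(per_rank_partials.begin(), per_rank_partials.end(),
                         0.0f);
}

int main() {
  std::vector<float> out_linear_partials{1.5f, 2.5f};  // two ranks
  std::cout << "reduced = " << AllReduceSum(out_linear_partials) << "\n";
  return 0;
}
```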
struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
: public PatternBase {
MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern(
......@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerEncoderPass : public FusePassBase {
public:
MultiDevicesFusedMultiTransformerEncoderPass();
virtual ~MultiDevicesFusedMultiTransformerEncoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_encoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
: public FusePassBase {
public:
......
......@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fused_multi_transformer_decoder_pass",
"fused_multi_transformer_encoder_fuse_qkv_pass",
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"fuse_multi_transformer_layer_pass",
......@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_decoder_pass", //
"fused_multi_transformer_encoder_fuse_qkv_pass", //
"fused_multi_transformer_decoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_encoder_pass", //
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", //
"fuse_multi_transformer_layer_pass", //
......