Unverified commit 29eec2dd authored by lzy, committed by GitHub

add multi_devices_fused_multi_transformer_encoder_pass, cherry-picked from #48349 (#49383)

Parent a2d7e1d7
......@@ -31,6 +31,8 @@ namespace framework {
namespace ir {
namespace patterns {
static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("layer_norm", "X");
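The new FFN_ACTS set generalizes the FFN pattern from gelu-only to any supported activation. As a standalone illustration of the membership test that assert_is_ops / assert_is_ops_input perform over this set (IsSupportedFfnAct is a hypothetical helper for this sketch, not a Paddle API):

#include <cassert>
#include <string>
#include <unordered_set>

// Same set as above; the pattern matches any op whose type is in it.
static const std::unordered_set<std::string> kFfnActs{"relu", "gelu"};

// Hypothetical helper illustrating the op-set check.
bool IsSupportedFfnAct(const std::string& op_type) {
  return kFfnActs.count(op_type) > 0;
}

int main() {
  assert(IsSupportedFfnAct("gelu"));
  assert(IsSupportedFfnAct("relu"));
  assert(!IsSupportedFfnAct("sigmoid"));  // other activations stay unfused
  return 0;
}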
......@@ -359,13 +361,13 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_ops_output(FFN_ACTS)
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
......@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
......@@ -678,13 +680,13 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_ops_output(FFN_ACTS)
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
......@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var});
......@@ -1026,13 +1028,13 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
->assert_is_op_output("elementwise_add")
->AsIntermediate()
->assert_is_op_input("gelu");
->assert_is_ops_input(FFN_ACTS);
auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
->assert_is_op_output("gelu")
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
->assert_is_ops_output(FFN_ACTS)
->AsIntermediate()
->assert_is_op_input("matmul_v2");
auto* ffn_matmul1 =
pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
......@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
.LinksTo({ffn_matmul0_out_var});
ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
.LinksTo({ffn_eltadd0_out_var});
ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
.LinksTo({ffn_matmul1_out_var});
ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
.LinksTo({ffn_c_allreduce_sum_out_var});
......@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) {
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
......@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute
fused_multi_transformer_op_desc.SetAttr("is_test", true);
......@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
ffn_act, ffn_act, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
......@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_act,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
......@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_act,
ffn_act_out,
ffn_eltadd_out});
// Remove unneeded nodes.
......@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) {
auto* matmul0_op = matmul0->Op();
auto* matmul_linear_op = matmul_linear->Op();
......@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
......@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
......@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_act,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
......@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_act,
ffn_act_out,
ffn_eltadd_out});
// Remove unneeded nodes.
......@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w,
Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b,
Node* ffn_act,
Node* ffn_output) {
auto* matmul_linear_op = matmul_linear->Op();
auto* ffn_matmul_1_op = ffn_matmul1->Op();
......@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon"));
fused_multi_transformer_op_desc.SetAttr("act_method",
ffn_act->Op()->Type());
// output dropout attribute
fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
......@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
......@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_matmul1_w,
ffn_eltadd0_b,
ffn_eltadd1_b,
ffn_act,
ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm,
......@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
ffn_eltadd1,
ffn_eltadd0_out,
ffn_eltadd1_out,
ffn_gelu,
ffn_gelu_out,
ffn_act,
ffn_act_out,
ffn_eltadd_out});
// Remove unneeded nodes.
......
......@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......
......@@ -37,12 +37,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
......@@ -73,6 +67,8 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(scale_q);
PATTERN_DECL_NODE(scale_q_out);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
......@@ -98,29 +94,30 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// while loop
PATTERN_DECL_NODE(while0);
// post layer_norm
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -131,6 +128,13 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
};
struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
......@@ -212,11 +216,127 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
};
struct MultiDevicesFusedMultiTransformerEncoderPattern : public PatternBase {
MultiDevicesFusedMultiTransformerEncoderPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern,
name_scope,
"multi_devices_fused_multi_transformer_encoder") {}
PDNode* operator()();
// Q, K, V path
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(c_identity0);
PATTERN_DECL_NODE(c_identity0_out);
PATTERN_DECL_NODE(c_identity1);
PATTERN_DECL_NODE(c_identity1_out);
PATTERN_DECL_NODE(c_identity2);
PATTERN_DECL_NODE(c_identity2_out);
PATTERN_DECL_NODE(matmul0);
PATTERN_DECL_NODE(matmul1);
PATTERN_DECL_NODE(matmul2);
PATTERN_DECL_NODE(matmul0_w);
PATTERN_DECL_NODE(matmul1_w);
PATTERN_DECL_NODE(matmul2_w);
PATTERN_DECL_NODE(matmul0_out);
PATTERN_DECL_NODE(matmul1_out);
PATTERN_DECL_NODE(matmul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(scale_q);
PATTERN_DECL_NODE(scale_q_out);
// Q, K matmul
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
// QK, V matmul
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_qkv_out);
// out linear
PATTERN_DECL_NODE(matmul_linear);
PATTERN_DECL_NODE(matmul_linear_w);
PATTERN_DECL_NODE(matmul_linear_out);
PATTERN_DECL_NODE(c_allreduce_sum);
PATTERN_DECL_NODE(c_allreduce_sum_out);
PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add
PATTERN_DECL_NODE(eltadd_out)
PATTERN_DECL_NODE(attention_output);
// post layer_norm
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_out);
// Feed Forward nodes
PATTERN_DECL_NODE(ffn_c_identity);
PATTERN_DECL_NODE(ffn_c_identity_out);
PATTERN_DECL_NODE(ffn_matmul0);
PATTERN_DECL_NODE(ffn_matmul0_w);
PATTERN_DECL_NODE(ffn_matmul0_out);
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
PATTERN_DECL_NODE(ffn_c_allreduce_sum);
PATTERN_DECL_NODE(ffn_c_allreduce_sum_out);
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out);
......@@ -224,6 +344,13 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
// output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out)
PATTERN_DECL_NODE(ffn_output);
PATTERN_DECL_NODE(ffn_layer_norm);
PATTERN_DECL_NODE(ffn_layer_norm_scale);
PATTERN_DECL_NODE(ffn_layer_norm_bias);
PATTERN_DECL_NODE(ffn_layer_norm_mean);
PATTERN_DECL_NODE(ffn_layer_norm_variance);
PATTERN_DECL_NODE(ffn_layer_norm_out);
};
struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
......@@ -313,8 +440,8 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd0_out);
PATTERN_DECL_NODE(ffn_gelu);
PATTERN_DECL_NODE(ffn_gelu_out);
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);
PATTERN_DECL_NODE(ffn_matmul1);
PATTERN_DECL_NODE(ffn_matmul1_w);
PATTERN_DECL_NODE(ffn_matmul1_out);
......@@ -362,6 +489,23 @@ class FusedMultiTransformerEncoderFuseQKVPass : public FusePassBase {
Scope* scope) const;
};
class MultiDevicesFusedMultiTransformerEncoderPass : public FusePassBase {
public:
MultiDevicesFusedMultiTransformerEncoderPass();
virtual ~MultiDevicesFusedMultiTransformerEncoderPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{
"multi_devices_fused_multi_transformer_encoder"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
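The new pass follows the same driver protocol as the existing encoder/decoder fuse passes. Condensed from the unit test below, typical usage looks like this (assumes graph is a populated std::unique_ptr<ir::Graph> with __param_scope__ attached):

auto pass = PassRegistry::Instance().Get(
    "multi_devices_fused_multi_transformer_encoder_pass");
graph.reset(pass->Apply(graph.release()));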
class MultiDevicesFusedMultiTransformerEncoderFuseQKVPass
: public FusePassBase {
public:
......
......@@ -56,8 +56,8 @@ Scope* CreateParamScope() {
// FFN: fc1 -> (gelu) -> fc2
AddVarToScope(param_scope, "ffn_weights0", {1024, 4096});
AddVarToScope(param_scope, "ffn_weights1", {4096, 1024});
AddVarToScope(param_scope, "ffn_bias_0", {4096});
AddVarToScope(param_scope, "ffn_bias_1", {1024});
AddVarToScope(param_scope, "ffn_bias0", {4096});
AddVarToScope(param_scope, "ffn_bias1", {1024});
return param_scope;
}
......@@ -65,10 +65,9 @@ Scope* CreateParamScope() {
TEST(FusedMultiTransformerEncoderPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, ln_scale, ln_bias) layer_norm -> layer_norm_out
// (layer_norm_out, weights_0) matmul_v2 -> matmul_out0
// (layer_norm_out, weights_1) matmul_v2 -> matmul_out1
// (layer_norm_out, weights_2) matmul_v2 -> matmul_out2
// (x, weights_0) matmul_v2 -> matmul_out0
// (x, weights_1) matmul_v2 -> matmul_out1
// (x, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
......@@ -78,7 +77,8 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0, transpose_1) matmul -> matmul_qk
// (transpose_0) scale -> scale_0
// (scale_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
......@@ -86,35 +86,28 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_out) elementwise_add -> attention_out
// (eltadd_linear) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
// (attention_out, scale, bias) layer_norm -> layer_norm_out
// (layer_norm_out, ffn_matmul0_w) matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
//
// (transpose_1, transpose_2) while -> decoder block
// (layer_norm_out, ffn_eltadd1) elementwise_add -> ffn_output
// (ffn_output, scale, bias) layer_norm -> ffn_layer_norm_out
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(x, ln_scale, ln_bias)[0];
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
auto* matmul_out_0 =
layers.matmul_v2(ln_out, weights_0, nullptr, false, true);
auto* matmul_out_1 =
layers.matmul_v2(ln_out, weights_1, nullptr, false, true);
auto* matmul_out_2 =
layers.matmul_v2(ln_out, weights_2, nullptr, false, true);
auto* matmul_out_0 = layers.matmul_v2(x, weights_0, nullptr, false, false);
auto* matmul_out_1 = layers.matmul_v2(x, weights_1, nullptr, false, false);
auto* matmul_out_2 = layers.matmul_v2(x, weights_2, nullptr, false, false);
auto* b0 = layers.data("bias_0", {1024}, true);
auto* b1 = layers.data("bias_1", {1024}, true);
......@@ -136,14 +129,13 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
// Link to decoder while block
layers.while_loop({transpose_1, transpose_2});
// q scale
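// 0.125 == 1/sqrt(head_dim), with head_dim = 64 from the {1, 128, 16, 64} reshape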
auto* scale_q = layers.scale(transpose_0, 0.125, 0, false);
// MHA: QK matmul
auto* matmul_qk =
layers.matmul(transpose_0, transpose_1, nullptr, false, true);
layers.matmul_v2(scale_q, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
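// note: {1, 1, 1, 128} broadcasts over heads and query positions (presumably a padding mask)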
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
......@@ -155,19 +147,18 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("weightsl", {1024, 1024}, true);
auto* bias_l = layers.data("bias_l", {1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, true);
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false);
auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
// FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
auto* ffn_ln_out =
layers.layer_norm(attention_out, ffn_ln_scale, ffn_ln_bias)[0];
// post LayerNorm
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0];
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
......@@ -175,7 +166,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_ln_out, ffn_weights0, nullptr, false, true);
layers.matmul_v2(ln_out, ffn_weights0, nullptr, false, true);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
......@@ -184,7 +175,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out);
// FFN: post LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0];
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
......@@ -203,12 +199,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 56,
num_nodes_after + 58,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_pass, The "
"node num in graph "
"should be %d, but the result is %d",
num_nodes_before - 56,
num_nodes_before - 58,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
......@@ -225,6 +221,183 @@ TEST(FusedMultiTransformerEncoderPass, pass_op_version_check) {
.IsPassCompatible("fused_multi_transformer_encoder_pass"));
}
TEST(MultiDevicesFusedMultiTransformerEncoderPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x) c_identity -> c_identity0_out
// (x) c_identity -> c_identity1_out
// (x) c_identity -> c_identity2_out
// (c_identity0_out, weights_0) matmul_v2 -> matmul_out0
// (c_identity1_out, weights_1) matmul_v2 -> matmul_out1
// (c_identity2_out, weights_2) matmul_v2 -> matmul_out2
// (matmul_out0, bias_0) elementwise_add -> eltadd_0
// (matmul_out1, bias_1) elementwise_add -> eltadd_1
// (matmul_out2, bias_2) elementwise_add -> eltadd_2
// (eltadd_0) reshape2 -> reshape_0
// (eltadd_1) reshape2 -> reshape_1
// (eltadd_2) reshape2 -> reshape_2
// (reshape_0) transpose2 -> transpose_0
// (reshape_1) transpose2 -> transpose_1
// (reshape_2) transpose2 -> transpose_2
// (transpose_0) scale -> scale_0
// (scale_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk
// (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_all_reduce -> c_all_reduce_out
// (c_all_reduce_out) elementwise_add -> eltadd_linear
// (eltadd_linear) elementwise_add -> attention_out
//
// (attention_out, scale, bias) layer_norm -> layer_norm_out
// (layer_norm_out) c_identity -> ffn_c_identity_out
// (ffn_c_identity_out, ffn_matmul0_w)matmul_v2 -> ffn_matmul0
// (ffn_matmul0, ffn_bias0) elementwise_add -> ffn_eltadd0
// (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out
// (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1
// (layer_norm_out, ffn_eltadd1) elementwise_add -> ffn_output
// (ffn_output, scale, bias) layer_norm -> ffn_layer_norm_out
Layers layers;
// MHA: pre LayerNorm
auto* x = layers.data("x", {1, 128, 1024});
auto* c_identity0_out = layers.c_identity(x);
auto* c_identity1_out = layers.c_identity(x);
auto* c_identity2_out = layers.c_identity(x);
// MHA: QKV fc
auto* weights_0 = layers.data("weights0", {1024, 1024}, true);
auto* weights_1 = layers.data("weights1", {1024, 1024}, true);
auto* weights_2 = layers.data("weights2", {1024, 1024}, true);
auto* matmul_out_0 =
layers.matmul_v2(c_identity0_out, weights_0, nullptr, false, false);
auto* matmul_out_1 =
layers.matmul_v2(c_identity1_out, weights_1, nullptr, false, false);
auto* matmul_out_2 =
layers.matmul_v2(c_identity2_out, weights_2, nullptr, false, false);
auto* b0 = layers.data("bias_0", {1024}, true);
auto* b1 = layers.data("bias_1", {1024}, true);
auto* b2 = layers.data("bias_2", {1024}, true);
auto* elementwise_out_0 =
layers.elementwise_add(matmul_out_0, b0, nullptr, 2);
auto* elementwise_out_1 =
layers.elementwise_add(matmul_out_1, b1, nullptr, 2);
auto* elementwise_out_2 =
layers.elementwise_add(matmul_out_2, b2, nullptr, 2);
std::vector<int> shape = {1, 128, 16, 64};
auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
std::vector<int> axis = {0, 2, 1, 3};
auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
// q scale
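// 0.125 == 1/sqrt(head_dim), head_dim = 64 as in the single-device test above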
auto* scale_q = layers.scale(transpose_0, 0.125, 0, false);
// MHA: QK matmul
auto* matmul_qk =
layers.matmul_v2(scale_q, transpose_1, nullptr, false, true);
auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
// MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
// MHA: out Linear
auto* weights_l = layers.data("weights_l", {1024, 1024}, true);
auto* bias_l = layers.data("bias_l", {1024}, true);
auto* linear_matmut_out =
layers.matmul_v2(reshape_qkv_out, weights_l, nullptr, false, false);
auto* c_allreduce_out = layers.c_allreduce_sum(linear_matmut_out);
auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
// post LayerNorm
auto* ln_scale = layers.data("ln_scale", {1024}, true);
auto* ln_bias = layers.data("ln_bias", {1024}, true);
auto* ln_out = layers.layer_norm(attention_out, ln_scale, ln_bias)[0];
auto* ffn_c_identity_out = layers.c_identity(ln_out);
// FFN: fc1 -> gelu -> fc2
auto* ffn_weights0 = layers.data("ffn_weights0", {1024, 4096}, true);
auto* ffn_weights1 = layers.data("ffn_weights1", {4096, 1024}, true);
auto* ffn_bias0 = layers.data("ffn_bias0", {4096}, true);
auto* ffn_bias1 = layers.data("ffn_bias1", {1024}, true);
auto* ffn_matmul0_out =
layers.matmul_v2(ffn_c_identity_out, ffn_weights0, nullptr, false, false);
auto* ffn_eltadd0_out =
layers.elementwise_add(ffn_matmul0_out, ffn_bias0, nullptr, 2);
auto* ffn_gelu_out = layers.gelu(ffn_eltadd0_out);
auto* ffn_matmul1_out =
layers.matmul_v2(ffn_gelu_out, ffn_weights1, nullptr, false, false);
auto* ffn_allreduce_out = layers.c_allreduce_sum(ffn_matmul1_out);
auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);
auto* ffn_out = layers.elementwise_add(ln_out, ffn_eltadd1_out);
// FFN: post LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true);
layers.layer_norm(ffn_out, ffn_ln_scale, ffn_ln_bias)[0];
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set("enable_int8", new bool(false));
auto pass = PassRegistry::Instance().Get(
"multi_devices_fused_multi_transformer_encoder_pass");
if (pass.get() == nullptr)
LOG(INFO)
<< "get multi_devices_fused_multi_transformer_encoder_pass failed";
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 70,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_pass, The "
"node num in graph "
"should be %d, but the result is %d",
num_nodes_before - 70,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder pass, "
"there should be one fused_multi_transformer op, "
"but the result is %d",
num_fused_nodes_after));
}
TEST(MultiDevicesFusedMultiTransformerEncoderPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible(
"multi_devices_fused_multi_transformer_encoder_pass"));
}
TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
......@@ -292,7 +465,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
......@@ -447,7 +620,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* matmul_qk = layers.matmul_v2(split_q, split_k, nullptr, false, true);
auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* bqk = layers.data("biasqk", {1, 1, 1, 128}, true);
auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1);
......@@ -542,4 +715,5 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
USE_PASS(fused_multi_transformer_encoder_pass);
USE_PASS(fused_multi_transformer_encoder_fuse_qkv_pass);
USE_PASS(multi_devices_fused_multi_transformer_encoder_pass);
USE_PASS(multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass);
......@@ -179,6 +179,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fused_multi_transformer_decoder_pass",
"fused_multi_transformer_encoder_fuse_qkv_pass",
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"fuse_multi_transformer_layer_pass",
......@@ -228,6 +229,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_decoder_pass", //
"fused_multi_transformer_encoder_fuse_qkv_pass", //
"fused_multi_transformer_decoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_encoder_pass", //
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", //
"fuse_multi_transformer_layer_pass", //
......
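With the pass appended to kGpuLowerPrecisionPasses and GpuPassStrategy, it runs during regular GPU IR optimization. One hedged way to exercise it from the C++ inference API (the model path and pool size are placeholders; the calls are standard paddle_infer::Config usage):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./multi_transformer_model");  // placeholder path
  config.EnableUseGpu(1024 /* MB initial pool */, 0 /* device id */);
  config.SwitchIrOptim(true);  // runs the GPU pass strategy, including this pass
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}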