Unverified commit ba4fbe71, authored by Kaipeng Deng, committed by GitHub

[cherry pick] fix memory copy in prepare_data of FusedMultiTransformer pass (#47308)

* fix memory copy in prepare_data. test=develop

* add cache_kv fp16 support. test=develop

* fit for simplify_with_basic_ops_pass. test=develop
Parent 7a1cf277
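In the decoder-pass hunks below, all three patterns stop matching the dropout ops that used to follow the QK softmax, the out-linear elementwise_add, and the FFN's second elementwise_add; softmax_qk_out and eltadd_linear_out now feed the next op directly, presumably because simplify_with_basic_ops_pass has already stripped dropout from the inference graph (the "fit for simplify_with_basic_ops_pass" item above). A minimal sketch of the matching attribute handling in BuildFusion, adapted from the hunks below (the surrounding pass code is omitted, so the fragment is illustrative rather than compilable on its own):

    // Dropout is assumed to be gone from the inference graph, so the fused op
    // no longer copies dropout_prob / is_test / dropout_implementation from a
    // matched dropout op and instead pins inference-time defaults:
    fused_multi_transformer_op_desc.SetAttr("is_test", true);
    fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);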
@@ -237,15 +237,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
  auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
      ->assert_is_op_output("softmax")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* dropout_qk =
-     pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
- auto* dropout_qk_out_var =
-     pattern->NewNode(dropout_qk_out_repr())
-         ->assert_is_op_output("dropout", "Out")
-         ->AsIntermediate()
-         ->assert_is_op_input("matmul_v2", "X");  // -> matmul_qkv
+     ->assert_is_op_input("matmul_v2", "X");
  // QK path Linsk
  matmul_qk->LinksFrom({transpose2_0_out_var, concat_0_out_var})
@@ -253,7 +245,6 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
      .LinksTo({eltadd_qk_out_var});
  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
- dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
  // QKV path Nodes
  auto* matmul_qkv =
@@ -294,14 +285,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
  auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
      ->assert_is_op_output("elementwise_add")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* dropout_linear =
-     pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
- auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
-     ->assert_is_op_output("dropout")
-     ->AsIntermediate()
-     ->assert_is_op_input("elementwise_add");
+     ->assert_is_op_input("elementwise_add");
  auto* eltadd_out =
      pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
@@ -310,7 +294,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
      ->AsIntermediate();
  // QKV path Links
- matmul_qkv->LinksFrom({dropout_qk_out_var, concat_1_out_var})
+ matmul_qkv->LinksFrom({softmax_qk_out_var, concat_1_out_var})
      .LinksTo({matmul_qkv_out_var});
  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
      .LinksTo({transpose2_qkv_out_var});
@@ -320,9 +304,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
      .LinksTo({matmul_linear_out_var});
  eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
      .LinksTo({eltadd_linear_out_var});
- dropout_linear->LinksFrom({eltadd_linear_out_var})
-     .LinksTo({dropout_linear_out_var});
- eltadd_out->LinksFrom({input0, dropout_linear_out_var})
+ eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
      .LinksTo({attention_output});
  // Feed Forward LayerNorm Nodes
@@ -358,7 +340,7 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
      ffn_layer_norm_mean_var,
      ffn_layer_norm_variance_var});
- // Feed Forward fc1 -> gelu -> fc2 -> dropout
+ // Feed Forward fc1 -> gelu -> fc2
  auto* ffn_matmul0 =
      pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
  auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
@@ -403,13 +385,6 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
  auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
      ->assert_is_op_output("elementwise_add")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* ffn_dropout =
-     pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
- auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
-     ->assert_is_op_output("dropout")
-     ->AsIntermediate()
      ->assert_is_op_input("elementwise_add");
  auto* ffn_eltadd_out =
@@ -427,9 +402,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
      .LinksTo({ffn_matmul1_out_var});
  ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
      .LinksTo({ffn_eltadd1_out_var});
- ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
- ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+ ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
      .LinksTo({ffn_output});
  return ffn_output;
@@ -575,15 +549,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
      ->assert_is_op_output("softmax")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* dropout_qk =
-     pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
- auto* dropout_qk_out_var =
-     pattern->NewNode(dropout_qk_out_repr())
-         ->assert_is_op_output("dropout", "Out")
-         ->AsIntermediate()
-         ->assert_is_op_input("matmul_v2", "X");  // -> matmul_qkv
+     ->assert_is_op_input("matmul_v2", "X");
  // QK path Linsk
  matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
@@ -591,7 +557,6 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
      .LinksTo({eltadd_qk_out_var});
  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
- dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
  // QKV path Nodes
  auto* matmul_qkv =
@@ -632,14 +597,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
      ->assert_is_op_output("elementwise_add")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* dropout_linear =
-     pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
- auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
-     ->assert_is_op_output("dropout")
-     ->AsIntermediate()
-     ->assert_is_op_input("elementwise_add");
+     ->assert_is_op_input("elementwise_add");
  auto* eltadd_out =
      pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
@@ -648,7 +606,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      ->AsIntermediate();
  // QKV path Links
- matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var})
+ matmul_qkv->LinksFrom({softmax_qk_out_var, concat_v_out_var})
      .LinksTo({matmul_qkv_out_var});
  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
      .LinksTo({transpose2_qkv_out_var});
@@ -658,9 +616,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      .LinksTo({matmul_linear_out_var});
  eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
      .LinksTo({eltadd_linear_out_var});
- dropout_linear->LinksFrom({eltadd_linear_out_var})
-     .LinksTo({dropout_linear_out_var});
- eltadd_out->LinksFrom({input0, dropout_linear_out_var})
+ eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
      .LinksTo({attention_output});
  // Feed Forward LayerNorm Nodes
@@ -696,7 +652,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      ffn_layer_norm_mean_var,
      ffn_layer_norm_variance_var});
- // Feed Forward fc1 -> gelu -> fc2 -> dropout
+ // Feed Forward fc1 -> gelu -> fc2
  auto* ffn_matmul0 =
      pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
  auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
@@ -741,13 +697,6 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
      ->assert_is_op_output("elementwise_add")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* ffn_dropout =
-     pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
- auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
-     ->assert_is_op_output("dropout")
-     ->AsIntermediate()
      ->assert_is_op_input("elementwise_add");
  auto* ffn_eltadd_out =
@@ -765,9 +714,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      .LinksTo({ffn_matmul1_out_var});
  ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
      .LinksTo({ffn_eltadd1_out_var});
- ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
- ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+ ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
      .LinksTo({ffn_output});
  return ffn_output;
@@ -922,15 +870,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
      ->assert_is_op_output("softmax")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* dropout_qk =
-     pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
- auto* dropout_qk_out_var =
-     pattern->NewNode(dropout_qk_out_repr())
-         ->assert_is_op_output("dropout", "Out")
-         ->AsIntermediate()
-         ->assert_is_op_input("matmul_v2", "X");  // -> matmul_qkv
+     ->assert_is_op_input("matmul_v2", "X");
  // QK path Linsk
  matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
@@ -938,7 +878,6 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
      .LinksTo({eltadd_qk_out_var});
  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
- dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
  // QKV path Nodes
  auto* matmul_qkv =
@@ -987,14 +926,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
      ->assert_is_op_output("elementwise_add")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* dropout_linear =
-     pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
- auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
-     ->assert_is_op_output("dropout")
-     ->AsIntermediate()
-     ->assert_is_op_input("elementwise_add");
+     ->assert_is_op_input("elementwise_add");
  auto* eltadd_out =
      pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
@@ -1003,7 +935,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      ->AsIntermediate();
  // QKV path Links
- matmul_qkv->LinksFrom({dropout_qk_out_var, concat_v_out_var})
+ matmul_qkv->LinksFrom({softmax_qk_out_var, concat_v_out_var})
      .LinksTo({matmul_qkv_out_var});
  transpose2_qkv->LinksFrom({matmul_qkv_out_var})
      .LinksTo({transpose2_qkv_out_var});
@@ -1015,9 +947,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      .LinksTo({c_allreduce_sum_out_var});
  eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
      .LinksTo({eltadd_linear_out_var});
- dropout_linear->LinksFrom({eltadd_linear_out_var})
-     .LinksTo({dropout_linear_out_var});
- eltadd_out->LinksFrom({input0, dropout_linear_out_var})
+ eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
      .LinksTo({attention_output});
  // Feed Forward LayerNorm Nodes
@@ -1063,7 +993,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
      .LinksTo({ffn_c_identity_out_var});
- // Feed Forward fc1 -> gelu -> fc2 -> dropout
+ // Feed Forward fc1 -> gelu -> fc2
  auto* ffn_matmul0 =
      pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
  auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
@@ -1117,13 +1047,6 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
  auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
      ->assert_is_op_output("elementwise_add")
      ->AsIntermediate()
-     ->assert_is_op_input("dropout");
- auto* ffn_dropout =
-     pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
- auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
-     ->assert_is_op_output("dropout")
-     ->AsIntermediate()
      ->assert_is_op_input("elementwise_add");
  auto* ffn_eltadd_out =
@@ -1143,9 +1066,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      .LinksTo({ffn_c_allreduce_sum_out_var});
  ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
      .LinksTo({ffn_eltadd1_out_var});
- ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
- ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var})
+ ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
      .LinksTo({ffn_output});
  return ffn_output;
@@ -1180,11 +1102,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      Node* transpose2_1_out,
      Node* transpose2_2_out,
      Node* eltadd_qk_b,
-     Node* dropout_qk,
      Node* reshape2_0,
      Node* matmul_linear_w,
      Node* eltadd_linear_b,
-     Node* dropout_linear,
      Node* ffn_layer_norm,
      Node* ffn_layer_norm_scale,
      Node* ffn_layer_norm_bias,
@@ -1194,7 +1114,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      Node* ffn_matmul1_w,
      Node* ffn_eltadd0_b,
      Node* ffn_eltadd1_b,
-     Node* ffn_dropout,
      Node* ffn_output) {
  // Calc index of transformer layer by LayerNorm Scale name
  // This calculation assumes:
@@ -1287,14 +1206,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      "epsilon", layer_norm->Op()->GetAttr("epsilon"));
  // output dropout attribute
- auto* dropout_op = dropout_linear->Op();
- fused_multi_transformer_op_desc.SetAttr(
-     "dropout_rate", dropout_op->GetAttr("dropout_prob"));
- fused_multi_transformer_op_desc.SetAttr("is_test",
-                                         dropout_op->GetAttr("is_test"));
- fused_multi_transformer_op_desc.SetAttr(
-     "dropout_implementation",
-     dropout_op->GetAttr("dropout_implementation"));
+ fused_multi_transformer_op_desc.SetAttr("is_test", true);
+ fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
  auto* fused_multi_transformer =
      graph->CreateOpNode(&fused_multi_transformer_op_desc);
@@ -1313,6 +1226,15 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
  IR_NODE_LINK_TO(slice_op, slice_out);
  IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
+ IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
  IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
  };
@@ -1451,11 +1373,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
  GET_IR_NODE_FROM_SUBGRAPH(
      ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(
-     ffn_dropout, ffn_dropout, fused_multi_transformer_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(
-     ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
      ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
@@ -1499,10 +1416,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      softmax_qk, softmax_qk, fused_multi_transformer_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(
      softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(
-     dropout_qk, dropout_qk, fused_multi_transformer_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(
-     dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
      matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
@@ -1531,10 +1444,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
      eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(
-     dropout_linear, dropout_linear, fused_multi_transformer_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(
-     dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
      eltadd_out, eltadd_out, fused_multi_transformer_pattern)
@@ -1554,11 +1463,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      transpose2_1_out,
      transpose2_2_out,
      eltadd_qk_b,
-     dropout_qk,
      reshape2_0,
      matmul_linear_w,
      eltadd_linear_b,
-     dropout_linear,
      ffn_layer_norm,
      ffn_layer_norm_scale,
      ffn_layer_norm_bias,
@@ -1568,12 +1475,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      ffn_matmul1_w,
      ffn_eltadd0_b,
      ffn_eltadd1_b,
-     ffn_dropout,
      ffn_output);
  std::unordered_set<const Node*> marked_nodes({layer_norm,
-     layer_norm_scale,
-     layer_norm_bias,
      layer_norm_mean,
      layer_norm_variance,
      layer_norm_out,
@@ -1613,8 +1517,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      eltadd_qk_out,
      softmax_qk,
      softmax_qk_out,
-     dropout_qk,
-     dropout_qk_out,
      transpose2_qkv,
      transpose2_qkv_out,
      matmul_qkv,
@@ -1623,17 +1525,11 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      transpose2_qkv,
      transpose2_qkv_out,
      matmul_linear,
-     matmul_linear_w,
      matmul_linear_out,
      eltadd_linear,
-     eltadd_linear_b,
      eltadd_linear_out,
-     dropout_linear,
-     dropout_linear_out,
      eltadd_out,
      ffn_layer_norm,
-     ffn_layer_norm_scale,
-     ffn_layer_norm_bias,
      ffn_layer_norm_mean,
      ffn_layer_norm_variance,
      ffn_layer_norm_out,
@@ -1647,8 +1543,6 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
      ffn_eltadd1_out,
      ffn_gelu,
      ffn_gelu_out,
-     ffn_dropout,
-     ffn_dropout_out,
      ffn_eltadd_out});
  // Remove unneeded nodes.
@@ -1850,11 +1744,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      Node* matmul0_w,
      Node* eltadd0_b,
      Node* eltadd_qk_b,
-     Node* dropout_qk,
      Node* reshape2_0,
      Node* matmul_linear_w,
      Node* eltadd_linear_b,
-     Node* dropout_linear,
      Node* ffn_layer_norm,
      Node* ffn_layer_norm_scale,
      Node* ffn_layer_norm_bias,
@@ -1864,7 +1756,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      Node* ffn_matmul1_w,
      Node* ffn_eltadd0_b,
      Node* ffn_eltadd1_b,
-     Node* ffn_dropout,
      Node* ffn_output) {
  // Calc index of transformer layer by LayerNorm Scale name
  // This calculation assumes:
@@ -1957,17 +1848,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      "epsilon", layer_norm->Op()->GetAttr("epsilon"));
  // output dropout attribute
- auto* dropout_op = dropout_linear->Op();
- fused_multi_transformer_op_desc.SetAttr(
-     "dropout_rate", dropout_op->GetAttr("dropout_prob"));
- fused_multi_transformer_op_desc.SetAttr("is_test",
-                                         dropout_op->GetAttr("is_test"));
- fused_multi_transformer_op_desc.SetAttr(
-     "dropout_implementation",
-     dropout_op->GetAttr("dropout_implementation"));
- // fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
- // fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
+ fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
+ fused_multi_transformer_op_desc.SetAttr("is_test", true);
  auto* fused_multi_transformer =
      graph->CreateOpNode(&fused_multi_transformer_op_desc);
@@ -1986,6 +1868,15 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
  IR_NODE_LINK_TO(slice_op, slice_out);
  IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
+ IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
  IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
  };
@@ -2116,12 +2007,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      ffn_eltadd1_out,
      fused_multi_transformer_fuse_qkv_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(
-     ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
-                           ffn_dropout_out,
-                           fused_multi_transformer_fuse_qkv_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
      ffn_eltadd_out,
      fused_multi_transformer_fuse_qkv_pattern)
@@ -2153,11 +2038,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
  GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
      softmax_qk_out,
      fused_multi_transformer_fuse_qkv_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(
-     dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
-                           dropout_qk_out,
-                           fused_multi_transformer_fuse_qkv_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
      matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
@@ -2193,12 +2073,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
  GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
      eltadd_linear_out,
      fused_multi_transformer_fuse_qkv_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
-                           dropout_linear,
-                           fused_multi_transformer_fuse_qkv_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
-                           dropout_linear_out,
-                           fused_multi_transformer_fuse_qkv_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
      eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
@@ -2212,11 +2086,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      matmul0_w,
      eltadd0_b,
      eltadd_qk_b,
-     dropout_qk,
      reshape2_0,
      matmul_linear_w,
      eltadd_linear_b,
-     dropout_linear,
      ffn_layer_norm,
      ffn_layer_norm_scale,
      ffn_layer_norm_bias,
@@ -2226,12 +2098,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      ffn_matmul1_w,
      ffn_eltadd0_b,
      ffn_eltadd1_b,
-     ffn_dropout,
      ffn_output);
  std::unordered_set<const Node*> marked_nodes({layer_norm,
-     layer_norm_scale,
-     layer_norm_bias,
      layer_norm_mean,
      layer_norm_variance,
      layer_norm_out,
@@ -2261,8 +2130,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      eltadd_qk_out,
      softmax_qk,
      softmax_qk_out,
-     dropout_qk,
-     dropout_qk_out,
      transpose2_qkv,
      transpose2_qkv_out,
      matmul_qkv,
@@ -2271,17 +2138,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      transpose2_qkv,
      transpose2_qkv_out,
      matmul_linear,
-     matmul_linear_w,
      matmul_linear_out,
      eltadd_linear,
-     eltadd_linear_b,
      eltadd_linear_out,
-     dropout_linear,
-     dropout_linear_out,
      eltadd_out,
      ffn_layer_norm,
-     ffn_layer_norm_scale,
-     ffn_layer_norm_bias,
      ffn_layer_norm_mean,
      ffn_layer_norm_variance,
      ffn_layer_norm_out,
@@ -2295,8 +2156,6 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      ffn_eltadd1_out,
      ffn_gelu,
      ffn_gelu_out,
-     ffn_dropout,
-     ffn_dropout_out,
      ffn_eltadd_out});
  // Remove unneeded nodes.
@@ -2500,11 +2359,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      Node* matmul0_w,
      Node* eltadd0_b,
      Node* eltadd_qk_b,
-     Node* dropout_qk,
      Node* reshape2_0,
      Node* matmul_linear_w,
      Node* eltadd_linear_b,
-     Node* dropout_linear,
      Node* ffn_layer_norm,
      Node* ffn_layer_norm_scale,
      Node* ffn_layer_norm_bias,
@@ -2514,7 +2371,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      Node* ffn_matmul1_w,
      Node* ffn_eltadd0_b,
      Node* ffn_eltadd1_b,
-     Node* ffn_dropout,
      Node* ffn_output) {
  // Calc index of transformer layer by LayerNorm Scale name
  // This calculation assumes:
@@ -2607,23 +2463,14 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      "epsilon", layer_norm->Op()->GetAttr("epsilon"));
  // output dropout attribute
- auto* dropout_op = dropout_linear->Op();
- fused_multi_transformer_op_desc.SetAttr(
-     "dropout_rate", dropout_op->GetAttr("dropout_prob"));
- fused_multi_transformer_op_desc.SetAttr("is_test",
-                                         dropout_op->GetAttr("is_test"));
- fused_multi_transformer_op_desc.SetAttr(
-     "dropout_implementation",
-     dropout_op->GetAttr("dropout_implementation"));
+ fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
+ fused_multi_transformer_op_desc.SetAttr("is_test", true);
  // parallel ring id
  auto* c_identity_op = c_identity->Op();
  fused_multi_transformer_op_desc.SetAttr("ring_id",
                                          c_identity_op->GetAttr("ring_id"));
- // fused_multi_transformer_op_desc.SetAttr("act_method", {"gelu"});
- // fused_multi_transformer_op_desc.SetAttr("trans_qkvw", {true});
  auto* fused_multi_transformer =
      graph->CreateOpNode(&fused_multi_transformer_op_desc);
  IR_NODE_LINK_TO(input0, fused_multi_transformer);
@@ -2641,6 +2488,15 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
  IR_NODE_LINK_TO(slice_op, slice_out);
  IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
+ IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
+ IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
  IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
  };
@@ -2790,12 +2646,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      ffn_eltadd1_out,
      fused_multi_transformer_fuse_qkv_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(
-     ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
-                           ffn_dropout_out,
-                           fused_multi_transformer_fuse_qkv_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
      ffn_eltadd_out,
      fused_multi_transformer_fuse_qkv_pattern)
@@ -2827,11 +2677,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
  GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
      softmax_qk_out,
      fused_multi_transformer_fuse_qkv_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(
-     dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
-                           dropout_qk_out,
-                           fused_multi_transformer_fuse_qkv_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
      matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
@@ -2873,12 +2718,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
  GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
      eltadd_linear_out,
      fused_multi_transformer_fuse_qkv_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
-                           dropout_linear,
-                           fused_multi_transformer_fuse_qkv_pattern)
- GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
-                           dropout_linear_out,
-                           fused_multi_transformer_fuse_qkv_pattern)
  GET_IR_NODE_FROM_SUBGRAPH(
      eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
@@ -2893,11 +2732,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      matmul0_w,
      eltadd0_b,
      eltadd_qk_b,
-     dropout_qk,
      reshape2_0,
      matmul_linear_w,
      eltadd_linear_b,
-     dropout_linear,
      ffn_layer_norm,
      ffn_layer_norm_scale,
      ffn_layer_norm_bias,
@@ -2907,12 +2744,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      ffn_matmul1_w,
      ffn_eltadd0_b,
      ffn_eltadd1_b,
-     ffn_dropout,
      ffn_output);
  std::unordered_set<const Node*> marked_nodes({layer_norm,
-     layer_norm_scale,
-     layer_norm_bias,
      layer_norm_mean,
      layer_norm_variance,
      layer_norm_out,
@@ -2944,8 +2778,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      eltadd_qk_out,
      softmax_qk,
      softmax_qk_out,
-     dropout_qk,
-     dropout_qk_out,
      transpose2_qkv,
      transpose2_qkv_out,
      matmul_qkv,
@@ -2954,19 +2786,13 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      transpose2_qkv,
      transpose2_qkv_out,
      matmul_linear,
-     matmul_linear_w,
      matmul_linear_out,
      c_allreduce_sum,
      c_allreduce_sum_out,
      eltadd_linear,
-     eltadd_linear_b,
      eltadd_linear_out,
-     dropout_linear,
-     dropout_linear_out,
      eltadd_out,
      ffn_layer_norm,
-     ffn_layer_norm_scale,
-     ffn_layer_norm_bias,
      ffn_layer_norm_mean,
      ffn_layer_norm_variance,
      ffn_layer_norm_out,
@@ -2984,8 +2810,6 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
      ffn_eltadd1_out,
      ffn_gelu,
      ffn_gelu_out,
-     ffn_dropout,
-     ffn_dropout_out,
      ffn_eltadd_out});
  // Remove unneeded nodes.
......
@@ -88,8 +88,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
  PATTERN_DECL_NODE(eltadd_qk_out);
  PATTERN_DECL_NODE(softmax_qk);
  PATTERN_DECL_NODE(softmax_qk_out);
- PATTERN_DECL_NODE(dropout_qk);
- PATTERN_DECL_NODE(dropout_qk_out);
  // QK, V matmul
  PATTERN_DECL_NODE(matmul_qkv);
@@ -106,8 +104,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
  PATTERN_DECL_NODE(eltadd_linear);
  PATTERN_DECL_NODE(eltadd_linear_b);
  PATTERN_DECL_NODE(eltadd_linear_out);
- PATTERN_DECL_NODE(dropout_linear);
- PATTERN_DECL_NODE(dropout_linear_out);
  // output elementwise_add
  PATTERN_DECL_NODE(eltadd_out)
@@ -137,8 +133,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
  PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd1_out);
- PATTERN_DECL_NODE(ffn_dropout);
- PATTERN_DECL_NODE(ffn_dropout_out);
  // output elementwise_add
  PATTERN_DECL_NODE(ffn_eltadd_out)
@@ -193,8 +187,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
  PATTERN_DECL_NODE(eltadd_qk_out);
  PATTERN_DECL_NODE(softmax_qk);
  PATTERN_DECL_NODE(softmax_qk_out);
- PATTERN_DECL_NODE(dropout_qk);
- PATTERN_DECL_NODE(dropout_qk_out);
  // QK, V matmul
  PATTERN_DECL_NODE(matmul_qkv);
@@ -211,8 +203,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
  PATTERN_DECL_NODE(eltadd_linear);
  PATTERN_DECL_NODE(eltadd_linear_b);
  PATTERN_DECL_NODE(eltadd_linear_out);
- PATTERN_DECL_NODE(dropout_linear);
- PATTERN_DECL_NODE(dropout_linear_out);
  // output elementwise_add
  PATTERN_DECL_NODE(eltadd_out)
@@ -239,8 +229,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
  PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd1_out);
- PATTERN_DECL_NODE(ffn_dropout);
- PATTERN_DECL_NODE(ffn_dropout_out);
  // output elementwise_add
  PATTERN_DECL_NODE(ffn_eltadd_out)
@@ -299,8 +287,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
  PATTERN_DECL_NODE(eltadd_qk_out);
  PATTERN_DECL_NODE(softmax_qk);
  PATTERN_DECL_NODE(softmax_qk_out);
- PATTERN_DECL_NODE(dropout_qk);
- PATTERN_DECL_NODE(dropout_qk_out);
  // QK, V matmul
  PATTERN_DECL_NODE(matmul_qkv);
@@ -319,8 +305,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
  PATTERN_DECL_NODE(eltadd_linear);
  PATTERN_DECL_NODE(eltadd_linear_b);
  PATTERN_DECL_NODE(eltadd_linear_out);
- PATTERN_DECL_NODE(dropout_linear);
- PATTERN_DECL_NODE(dropout_linear_out);
  // output elementwise_add
  PATTERN_DECL_NODE(eltadd_out)
@@ -351,8 +335,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
  PATTERN_DECL_NODE(ffn_eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(ffn_eltadd1_out);
- PATTERN_DECL_NODE(ffn_dropout);
- PATTERN_DECL_NODE(ffn_dropout_out);
  // output elementwise_add
  PATTERN_DECL_NODE(ffn_eltadd_out)
......
@@ -85,13 +85,11 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
  // (transpose_0, transpose_1) matmul -> matmul_qk
  // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
  // (eltadd_qk) softmax -> softmax_qk
- // (softmax_qk) dropout -> dropout_qk
- // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
+ // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
  // (matmul_qkv) transpose -> transpose_qkv
  // (transpose_qkv) reshape -> reshape_qkv
  // (reshape_qkv) matmul_v2 -> matmul_linear
  // (matmul_linear) elementwise_add -> eltadd_linear
- // (eltadd_linear) dropout -> dropout_linear
  // (eltadd_out) elementwise_add -> attention_out
  //
  // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
@@ -100,8 +98,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
  // (ffn_eltadd0) gelu -> ffn_gelu
  // (ffn_gelu) matmul_v2 -> ffn_matmul1
  // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
- // (ffn_eltadd1) dropout -> ffn_dropout
- // (attention_out, ffn_dropout) elementwise_add -> ffn_output
+ // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
  Layers layers;
  // MHA: pre LayerNorm
@@ -154,10 +151,9 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
  auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
  auto* softmax_qk = layers.softmax(elementwise_qk, -1);
- auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
  // MHA: QKV matmul
- auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
+ auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
  auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
  auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
@@ -170,9 +166,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
  auto* linear_eltadd_out =
      layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
- auto* dropout_qkv =
-     layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
- auto* attention_out = layers.elementwise_add(x, dropout_qkv);
+ auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
  // FFN: pre LayerNorm
  auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
@@ -195,9 +189,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
  auto* ffn_eltadd1_out =
      layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
- // FFN: dropout -> elementwise_add
- auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
- layers.elementwise_add(attention_out, ffn_dropout);
+ layers.elementwise_add(attention_out, ffn_eltadd1_out);
  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  graph->Set("__param_scope__", CreateParamScope());
@@ -215,12 +207,12 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
  int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
  PADDLE_ENFORCE_EQ(num_nodes_before,
-                   num_nodes_after + 72,
+                   num_nodes_after + 60,
                    platform::errors::InvalidArgument(
                        "After the fused_multi_transformer_decoder_pass, The "
                        "node num in graph "
                        "should be %d, but the result is %d",
-                       num_nodes_before - 72,
+                       num_nodes_before - 60,
                        num_nodes_after));
  PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                    1,
@@ -253,13 +245,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
  // (split_q, split_k) matmul -> matmul_qk
  // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
  // (eltadd_qk) softmax -> softmax_qk
- // (softmax_qk) dropout -> dropout_qk
- // (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
+ // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
  // (matmul_qkv) transpose -> transpose_qkv
  // (transpose_qkv) reshape -> reshape_qkv
  // (reshape_qkv) matmul_v2 -> matmul_linear
  // (matmul_linear) elementwise_add -> eltadd_linear
- // (eltadd_linear) dropout -> dropout_linear
  // (eltadd_out) elementwise_add -> attention_out
  //
  // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
@@ -268,8 +258,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
  // (ffn_eltadd0) gelu -> ffn_gelu
  // (ffn_gelu) matmul_v2 -> ffn_matmul1
  // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
- // (ffn_eltadd1) dropout -> ffn_dropout
- // (attention_out, ffn_dropout) elementwise_add -> ffn_output
+ // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
  //
  // (transpose_1, transpose_2) while -> decoder block
@@ -313,10 +302,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
  auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
  auto* softmax_qk = layers.softmax(elementwise_qk, -1);
- auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
  // MHA: QKV matmul
- auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v);
+ auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
  auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
  auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
@@ -329,9 +317,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
  auto* linear_eltadd_out =
      layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
- auto* dropout_qkv =
-     layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
- auto* attention_out = layers.elementwise_add(x, dropout_qkv);
+ auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
  // FFN: pre LayerNorm
  auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
@@ -354,9 +340,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
  auto* ffn_eltadd1_out =
      layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
- // FFN: dropout -> elementwise_add
- auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
- layers.elementwise_add(attention_out, ffn_dropout);
+ layers.elementwise_add(attention_out, ffn_eltadd1_out);
  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
  graph->Set("__param_scope__", CreateParamScope());
@@ -375,11 +359,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
  PADDLE_ENFORCE_EQ(
      num_nodes_before,
-     num_nodes_after + 62,
+     num_nodes_after + 50,
      platform::errors::InvalidArgument(
          "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
          "The node num in graph should be %d, but the result is %d",
-         num_nodes_before - 62,
+         num_nodes_before - 50,
          num_nodes_after));
  PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                    1,
...@@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_allreduce_sum -> c_all_reduce_out // (matmul_linear) c_allreduce_sum -> c_all_reduce_out
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_allreduce_sum -> c_allreduce_out // (ffn_matmul1) c_allreduce_sum -> c_allreduce_out
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 70, num_nodes_after + 58,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, " "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 70, num_nodes_before - 58,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
......
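In short, the decoder-pass tests above drop the dropout ops that used to follow the QK softmax, the output-projection bias add, and the FFN's second bias add from their reference graphs, which is why the expected node-count deltas shrink from 72, 62 and 70 to 60, 50 and 58. A minimal sketch of the attention tail as the tests now construct it, pieced together from the new right-hand lines above (a fragment of the test body, not a standalone program):

    // softmax output feeds the QK*V matmul directly; no dropout_qk in between
    auto* softmax_qk = layers.softmax(elementwise_qk, -1);
    auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
    auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
    auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
    // the residual add likewise consumes ffn_eltadd1_out directly (no ffn_dropout)
    layers.elementwise_add(attention_out, ffn_eltadd1_out);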
...@@ -227,15 +227,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -227,15 +227,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax") ->assert_is_op_output("softmax")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout"); ->assert_is_op_input("matmul_v2", "X");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links // QK path Links
matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}) matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var})
...@@ -243,7 +235,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -243,7 +235,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes // QKV path Nodes
auto* matmul_qkv = auto* matmul_qkv =
...@@ -284,14 +275,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -284,14 +275,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout"); ->assert_is_op_input("elementwise_add");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out = auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
...@@ -300,7 +284,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -300,7 +284,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
->AsIntermediate(); ->AsIntermediate();
// QKV path Links // QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, transpose2_2_out_var}) matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var})
.LinksTo({matmul_qkv_out_var}); .LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var}) transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var}); .LinksTo({transpose2_qkv_out_var});
...@@ -310,9 +294,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -310,9 +294,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
.LinksTo({matmul_linear_out_var}); .LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var}); .LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var}) eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output}); .LinksTo({attention_output});
// while loop // while loop
...@@ -352,7 +334,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -352,7 +334,7 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
ffn_layer_norm_mean_var, ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var}); ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout // Feed Forward fc1 -> gelu -> fc2
auto* ffn_matmul0 = auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
...@@ -397,13 +379,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -397,13 +379,6 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add"); ->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out = auto* ffn_eltadd_out =
...@@ -421,9 +396,8 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() { ...@@ -421,9 +396,8 @@ PDNode* FusedMultiTransformerEncoderPattern::operator()() {
.LinksTo({ffn_matmul1_out_var}); .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var}); .LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
.LinksTo({ffn_output}); .LinksTo({ffn_output});
return ffn_output; return ffn_output;
...@@ -545,15 +519,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -545,15 +519,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax") ->assert_is_op_output("softmax")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout"); ->assert_is_op_input("matmul_v2", "X");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
...@@ -561,7 +527,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -561,7 +527,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes // QKV path Nodes
auto* matmul_qkv = auto* matmul_qkv =
...@@ -602,14 +567,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -602,14 +567,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout"); ->assert_is_op_input("elementwise_add");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out = auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
...@@ -618,7 +576,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -618,7 +576,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
->AsIntermediate(); ->AsIntermediate();
// QKV path Links // QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var}) matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var})
.LinksTo({matmul_qkv_out_var}); .LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var}) transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var}); .LinksTo({transpose2_qkv_out_var});
...@@ -628,9 +586,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -628,9 +586,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
.LinksTo({matmul_linear_out_var}); .LinksTo({matmul_linear_out_var});
eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var}) eltadd_linear->LinksFrom({matmul_linear_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var}); .LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var}) eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output}); .LinksTo({attention_output});
// Feed Forward LayerNorm Nodes // Feed Forward LayerNorm Nodes
...@@ -666,7 +622,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -666,7 +622,7 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
ffn_layer_norm_mean_var, ffn_layer_norm_mean_var,
ffn_layer_norm_variance_var}); ffn_layer_norm_variance_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout // Feed Forward fc1 -> gelu -> fc2
auto* ffn_matmul0 = auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
...@@ -711,13 +667,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -711,13 +667,6 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add"); ->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out = auto* ffn_eltadd_out =
...@@ -735,9 +684,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -735,9 +684,8 @@ PDNode* FusedMultiTransformerEncoderFuseQKVPattern::operator()() {
.LinksTo({ffn_matmul1_out_var}); .LinksTo({ffn_matmul1_out_var});
ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var}) ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var}); .LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
.LinksTo({ffn_output}); .LinksTo({ffn_output});
return ffn_output; return ffn_output;
...@@ -868,15 +816,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -868,15 +816,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr()) auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())
->assert_is_op_output("softmax") ->assert_is_op_output("softmax")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout"); ->assert_is_op_input("matmul_v2", "X");
auto* dropout_qk =
pattern->NewNode(dropout_qk_repr())->assert_is_op("dropout");
auto* dropout_qk_out_var =
pattern->NewNode(dropout_qk_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate()
->assert_is_op_input("matmul_v2", "X"); // -> matmul_qkv
// QK path Links // QK path Links
matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var}) matmul_qk->LinksFrom({split0_q_out_var, split0_k_out_var})
...@@ -884,7 +824,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -884,7 +824,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}) eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
dropout_qk->LinksFrom({softmax_qk_out_var}).LinksTo({dropout_qk_out_var});
// QKV path Nodes // QKV path Nodes
auto* matmul_qkv = auto* matmul_qkv =
...@@ -933,14 +872,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -933,14 +872,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr()) auto* eltadd_linear_out_var = pattern->NewNode(eltadd_linear_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout"); ->assert_is_op_input("elementwise_add");
auto* dropout_linear =
pattern->NewNode(dropout_linear_repr())->assert_is_op("dropout");
auto* dropout_linear_out_var = pattern->NewNode(dropout_linear_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add");
auto* eltadd_out = auto* eltadd_out =
pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd_out_repr())->assert_is_op("elementwise_add");
...@@ -949,7 +881,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -949,7 +881,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
->AsIntermediate(); ->AsIntermediate();
// QKV path Links // QKV path Links
matmul_qkv->LinksFrom({dropout_qk_out_var, split0_v_out_var}) matmul_qkv->LinksFrom({softmax_qk_out_var, split0_v_out_var})
.LinksTo({matmul_qkv_out_var}); .LinksTo({matmul_qkv_out_var});
transpose2_qkv->LinksFrom({matmul_qkv_out_var}) transpose2_qkv->LinksFrom({matmul_qkv_out_var})
.LinksTo({transpose2_qkv_out_var}); .LinksTo({transpose2_qkv_out_var});
...@@ -961,9 +893,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -961,9 +893,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
.LinksTo({c_allreduce_sum_out_var}); .LinksTo({c_allreduce_sum_out_var});
eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var}) eltadd_linear->LinksFrom({c_allreduce_sum_out_var, eltadd_linear_b_var})
.LinksTo({eltadd_linear_out_var}); .LinksTo({eltadd_linear_out_var});
dropout_linear->LinksFrom({eltadd_linear_out_var}) eltadd_out->LinksFrom({input0, eltadd_linear_out_var})
.LinksTo({dropout_linear_out_var});
eltadd_out->LinksFrom({input0, dropout_linear_out_var})
.LinksTo({attention_output}); .LinksTo({attention_output});
// Feed Forward LayerNorm Nodes // Feed Forward LayerNorm Nodes
...@@ -1009,7 +939,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -1009,7 +939,7 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
ffn_c_identity->LinksFrom({ffn_layer_norm_out_var}) ffn_c_identity->LinksFrom({ffn_layer_norm_out_var})
.LinksTo({ffn_c_identity_out_var}); .LinksTo({ffn_c_identity_out_var});
// Feed Forward fc1 -> gelu -> fc2 -> dropout // Feed Forward fc1 -> gelu -> fc2
auto* ffn_matmul0 = auto* ffn_matmul0 =
pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2"); pattern->NewNode(ffn_matmul0_repr())->assert_is_op("matmul_v2");
auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr()) auto* ffn_matmul0_w_var = pattern->NewNode(ffn_matmul0_w_repr())
...@@ -1063,13 +993,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -1063,13 +993,6 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr()) auto* ffn_eltadd1_out_var = pattern->NewNode(ffn_eltadd1_out_repr())
->assert_is_op_output("elementwise_add") ->assert_is_op_output("elementwise_add")
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("dropout");
auto* ffn_dropout =
pattern->NewNode(ffn_dropout_repr())->assert_is_op("dropout");
auto* ffn_dropout_out_var = pattern->NewNode(ffn_dropout_out_repr())
->assert_is_op_output("dropout")
->AsIntermediate()
->assert_is_op_input("elementwise_add"); ->assert_is_op_input("elementwise_add");
auto* ffn_eltadd_out = auto* ffn_eltadd_out =
...@@ -1089,9 +1012,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() { ...@@ -1089,9 +1012,8 @@ PDNode* MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern::operator()() {
.LinksTo({ffn_c_allreduce_sum_out_var}); .LinksTo({ffn_c_allreduce_sum_out_var});
ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var}) ffn_eltadd1->LinksFrom({ffn_c_allreduce_sum_out_var, ffn_eltadd1_b_var})
.LinksTo({ffn_eltadd1_out_var}); .LinksTo({ffn_eltadd1_out_var});
ffn_dropout->LinksFrom({ffn_eltadd1_out_var}).LinksTo({ffn_dropout_out_var});
ffn_eltadd_out->LinksFrom({attention_output, ffn_dropout_out_var}) ffn_eltadd_out->LinksFrom({attention_output, ffn_eltadd1_out_var})
.LinksTo({ffn_output}); .LinksTo({ffn_output});
return ffn_output; return ffn_output;
...@@ -1253,11 +1175,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1253,11 +1175,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* transpose2_1_out, Node* transpose2_1_out,
Node* transpose2_2_out, Node* transpose2_2_out,
Node* eltadd_qk_b, Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0, Node* reshape2_0,
Node* matmul_linear_w, Node* matmul_linear_w,
Node* eltadd_linear_b, Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0, Node* while0,
Node* ffn_layer_norm, Node* ffn_layer_norm,
Node* ffn_layer_norm_scale, Node* ffn_layer_norm_scale,
...@@ -1268,7 +1188,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1268,7 +1188,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) { Node* ffn_output) {
auto reshape_desc = reshape2_0->Op(); auto reshape_desc = reshape2_0->Op();
int num_head = int num_head =
...@@ -1375,7 +1294,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1375,7 +1294,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32)); fill_const_op_desc.SetAttr(
"dtype",
static_cast<int>(framework::TransToProtoVarType(wq_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...@@ -1409,15 +1330,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1409,15 +1330,8 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute fused_multi_transformer_op_desc.SetAttr("is_test", true);
auto* dropout_op = dropout_linear->Op(); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
fused_multi_transformer_op_desc.SetAttr(
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
auto* fused_multi_transformer = auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc); graph->CreateOpNode(&fused_multi_transformer_op_desc);
...@@ -1433,6 +1347,15 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1433,6 +1347,15 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
IR_NODE_LINK_TO(fill_const_op, cache_kv); IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input // rewrite while OP input
...@@ -1620,11 +1543,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1620,11 +1543,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern); ffn_eltadd1_out, ffn_eltadd1_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout_out, ffn_dropout_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern) ffn_eltadd_out, ffn_eltadd_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
...@@ -1668,11 +1586,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1668,11 +1586,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
softmax_qk, softmax_qk, fused_multi_transformer_pattern); softmax_qk, softmax_qk, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern); softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk_out, dropout_qk_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_pattern); matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
...@@ -1700,11 +1613,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1700,11 +1613,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern) eltadd_linear_b, eltadd_linear_b, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern) eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear, dropout_linear, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH(
dropout_linear_out, dropout_linear_out, fused_multi_transformer_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_pattern) eltadd_out, eltadd_out, fused_multi_transformer_pattern)
...@@ -1723,11 +1631,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1723,11 +1631,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
transpose2_1_out, transpose2_1_out,
transpose2_2_out, transpose2_2_out,
eltadd_qk_b, eltadd_qk_b,
dropout_qk,
reshape2_0, reshape2_0,
matmul_linear_w, matmul_linear_w,
eltadd_linear_b, eltadd_linear_b,
dropout_linear,
while0, while0,
ffn_layer_norm, ffn_layer_norm,
ffn_layer_norm_scale, ffn_layer_norm_scale,
...@@ -1738,12 +1644,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1738,12 +1644,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_dropout,
ffn_output); ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm, std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean, layer_norm_mean,
layer_norm_variance, layer_norm_variance,
layer_norm_out, layer_norm_out,
...@@ -1777,8 +1680,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1777,8 +1680,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
softmax_qk_out, softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv, transpose2_qkv,
transpose2_qkv_out, transpose2_qkv_out,
matmul_qkv, matmul_qkv,
...@@ -1787,17 +1688,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1787,17 +1688,11 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
transpose2_qkv, transpose2_qkv,
transpose2_qkv_out, transpose2_qkv_out,
matmul_linear, matmul_linear,
matmul_linear_w,
matmul_linear_out, matmul_linear_out,
eltadd_linear, eltadd_linear,
eltadd_linear_b,
eltadd_linear_out, eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out, eltadd_out,
ffn_layer_norm, ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean, ffn_layer_norm_mean,
ffn_layer_norm_variance, ffn_layer_norm_variance,
ffn_layer_norm_out, ffn_layer_norm_out,
...@@ -1811,8 +1706,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ...@@ -1811,8 +1706,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_gelu,
ffn_gelu_out, ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
...@@ -2016,11 +1909,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2016,11 +1909,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* split0_k_out, Node* split0_k_out,
Node* split0_v_out, Node* split0_v_out,
Node* eltadd_qk_b, Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0, Node* reshape2_0,
Node* matmul_linear_w, Node* matmul_linear_w,
Node* eltadd_linear_b, Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0, Node* while0,
Node* ffn_layer_norm, Node* ffn_layer_norm,
Node* ffn_layer_norm_scale, Node* ffn_layer_norm_scale,
...@@ -2031,7 +1922,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2031,7 +1922,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) { Node* ffn_output) {
auto reshape_desc = reshape2_0->Op(); auto reshape_desc = reshape2_0->Op();
int num_head = int num_head =
...@@ -2104,7 +1994,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2104,7 +1994,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32)); fill_const_op_desc.SetAttr("dtype",
static_cast<int>(framework::TransToProtoVarType(
qkv_w_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...@@ -2139,14 +2031,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2139,14 +2031,8 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute // output dropout attribute
auto* dropout_op = dropout_linear->Op(); fused_multi_transformer_op_desc.SetAttr("is_test", true);
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
auto* fused_multi_transformer = auto* fused_multi_transformer =
graph->CreateOpNode(&fused_multi_transformer_op_desc); graph->CreateOpNode(&fused_multi_transformer_op_desc);
...@@ -2162,6 +2048,15 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2162,6 +2048,15 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(fill_const_op, cache_kv); IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input // rewrite while OP input
...@@ -2315,12 +2210,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2315,12 +2210,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_eltadd1_out, ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out, ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern) fused_multi_transformer_fuse_qkv_pattern)
...@@ -2352,11 +2241,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2352,11 +2241,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out, softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
...@@ -2392,12 +2276,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2392,12 +2276,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out, eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern) fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern) eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern)
...@@ -2416,11 +2294,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2416,11 +2294,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_k_out, split0_k_out,
split0_v_out, split0_v_out,
eltadd_qk_b, eltadd_qk_b,
dropout_qk,
reshape2_0, reshape2_0,
matmul_linear_w, matmul_linear_w,
eltadd_linear_b, eltadd_linear_b,
dropout_linear,
while0, while0,
ffn_layer_norm, ffn_layer_norm,
ffn_layer_norm_scale, ffn_layer_norm_scale,
...@@ -2431,12 +2307,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2431,12 +2307,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_dropout,
ffn_output); ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm, std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean, layer_norm_mean,
layer_norm_variance, layer_norm_variance,
layer_norm_out, layer_norm_out,
...@@ -2458,8 +2331,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2458,8 +2331,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
softmax_qk_out, softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv, transpose2_qkv,
transpose2_qkv_out, transpose2_qkv_out,
matmul_qkv, matmul_qkv,
...@@ -2468,17 +2339,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2468,17 +2339,11 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
transpose2_qkv, transpose2_qkv,
transpose2_qkv_out, transpose2_qkv_out,
matmul_linear, matmul_linear,
matmul_linear_w,
matmul_linear_out, matmul_linear_out,
eltadd_linear, eltadd_linear,
eltadd_linear_b,
eltadd_linear_out, eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out, eltadd_out,
ffn_layer_norm, ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean, ffn_layer_norm_mean,
ffn_layer_norm_variance, ffn_layer_norm_variance,
ffn_layer_norm_out, ffn_layer_norm_out,
...@@ -2492,8 +2357,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2492,8 +2357,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_gelu,
ffn_gelu_out, ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
...@@ -2700,11 +2563,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2700,11 +2563,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* split0_k_out, Node* split0_k_out,
Node* split0_v_out, Node* split0_v_out,
Node* eltadd_qk_b, Node* eltadd_qk_b,
Node* dropout_qk,
Node* reshape2_0, Node* reshape2_0,
Node* matmul_linear_w, Node* matmul_linear_w,
Node* eltadd_linear_b, Node* eltadd_linear_b,
Node* dropout_linear,
Node* while0, Node* while0,
Node* ffn_layer_norm, Node* ffn_layer_norm,
Node* ffn_layer_norm_scale, Node* ffn_layer_norm_scale,
...@@ -2715,7 +2576,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2715,7 +2576,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
Node* ffn_matmul1_w, Node* ffn_matmul1_w,
Node* ffn_eltadd0_b, Node* ffn_eltadd0_b,
Node* ffn_eltadd1_b, Node* ffn_eltadd1_b,
Node* ffn_dropout,
Node* ffn_output) { Node* ffn_output) {
auto reshape_desc = reshape2_0->Op(); auto reshape_desc = reshape2_0->Op();
int num_head = int num_head =
...@@ -2789,7 +2649,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2789,7 +2649,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
fill_const_op_desc.SetAttr("input_dim_idx", 0); fill_const_op_desc.SetAttr("input_dim_idx", 0);
fill_const_op_desc.SetAttr("output_dim_idx", 1); fill_const_op_desc.SetAttr("output_dim_idx", 1);
fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("value", 0);
fill_const_op_desc.SetAttr("dtype", static_cast<int>(proto::VarType::FP32)); fill_const_op_desc.SetAttr("dtype",
static_cast<int>(framework::TransToProtoVarType(
qkv_w_tensor->dtype())));
auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
...@@ -2824,14 +2686,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2824,14 +2686,8 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
"epsilon", layer_norm->Op()->GetAttr("epsilon")); "epsilon", layer_norm->Op()->GetAttr("epsilon"));
// output dropout attribute // output dropout attribute
auto* dropout_op = dropout_linear->Op(); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
fused_multi_transformer_op_desc.SetAttr( fused_multi_transformer_op_desc.SetAttr("is_test", true);
"dropout_rate", dropout_op->GetAttr("dropout_prob"));
fused_multi_transformer_op_desc.SetAttr("is_test",
dropout_op->GetAttr("is_test"));
fused_multi_transformer_op_desc.SetAttr(
"dropout_implementation",
dropout_op->GetAttr("dropout_implementation"));
// parallel ring id // parallel ring id
auto* c_identity_op = c_identity->Op(); auto* c_identity_op = c_identity->Op();
...@@ -2852,6 +2708,15 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -2852,6 +2708,15 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(fill_const_op, cache_kv); IR_NODE_LINK_TO(fill_const_op, cache_kv);
IR_NODE_LINK_TO(cache_kv, fused_multi_transformer); IR_NODE_LINK_TO(cache_kv, fused_multi_transformer);
IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_scale, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_layer_norm_bias, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul0_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_matmul1_w, fused_multi_transformer);
IR_NODE_LINK_TO(ffn_eltadd1_b, fused_multi_transformer);
IR_NODE_LINK_TO(fused_multi_transformer, ffn_output); IR_NODE_LINK_TO(fused_multi_transformer, ffn_output);
// rewrite while OP input // rewrite while OP input
...@@ -3024,12 +2889,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3024,12 +2889,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_eltadd1_out, ffn_eltadd1_out,
fused_multi_transformer_fuse_qkv_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ffn_dropout, ffn_dropout, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_dropout_out,
ffn_dropout_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out, GET_IR_NODE_FROM_SUBGRAPH(ffn_eltadd_out,
ffn_eltadd_out, ffn_eltadd_out,
fused_multi_transformer_fuse_qkv_pattern) fused_multi_transformer_fuse_qkv_pattern)
...@@ -3061,11 +2920,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3061,11 +2920,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out,
softmax_qk_out, softmax_qk_out,
fused_multi_transformer_fuse_qkv_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_qk, dropout_qk, fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH(dropout_qk_out,
dropout_qk_out,
fused_multi_transformer_fuse_qkv_pattern)
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern); matmul_qkv, matmul_qkv, fused_multi_transformer_fuse_qkv_pattern);
...@@ -3107,12 +2961,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3107,12 +2961,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out,
eltadd_linear_out, eltadd_linear_out,
fused_multi_transformer_fuse_qkv_pattern); fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear,
dropout_linear,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dropout_linear_out,
dropout_linear_out,
fused_multi_transformer_fuse_qkv_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern); eltadd_out, eltadd_out, fused_multi_transformer_fuse_qkv_pattern);
...@@ -3132,11 +2980,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3132,11 +2980,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
split0_k_out, split0_k_out,
split0_v_out, split0_v_out,
eltadd_qk_b, eltadd_qk_b,
dropout_qk,
reshape2_0, reshape2_0,
matmul_linear_w, matmul_linear_w,
eltadd_linear_b, eltadd_linear_b,
dropout_linear,
while0, while0,
ffn_layer_norm, ffn_layer_norm,
ffn_layer_norm_scale, ffn_layer_norm_scale,
...@@ -3147,12 +2993,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3147,12 +2993,9 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_matmul1_w, ffn_matmul1_w,
ffn_eltadd0_b, ffn_eltadd0_b,
ffn_eltadd1_b, ffn_eltadd1_b,
ffn_dropout,
ffn_output); ffn_output);
std::unordered_set<const Node*> marked_nodes({layer_norm, std::unordered_set<const Node*> marked_nodes({layer_norm,
layer_norm_scale,
layer_norm_bias,
layer_norm_mean, layer_norm_mean,
layer_norm_variance, layer_norm_variance,
layer_norm_out, layer_norm_out,
...@@ -3176,8 +3019,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3176,8 +3019,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
eltadd_qk_out, eltadd_qk_out,
softmax_qk, softmax_qk,
softmax_qk_out, softmax_qk_out,
dropout_qk,
dropout_qk_out,
transpose2_qkv, transpose2_qkv,
transpose2_qkv_out, transpose2_qkv_out,
matmul_qkv, matmul_qkv,
...@@ -3186,19 +3027,13 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3186,19 +3027,13 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
transpose2_qkv, transpose2_qkv,
transpose2_qkv_out, transpose2_qkv_out,
matmul_linear, matmul_linear,
matmul_linear_w,
matmul_linear_out, matmul_linear_out,
c_allreduce_sum, c_allreduce_sum,
c_allreduce_sum_out, c_allreduce_sum_out,
eltadd_linear, eltadd_linear,
eltadd_linear_b,
eltadd_linear_out, eltadd_linear_out,
dropout_linear,
dropout_linear_out,
eltadd_out, eltadd_out,
ffn_layer_norm, ffn_layer_norm,
ffn_layer_norm_scale,
ffn_layer_norm_bias,
ffn_layer_norm_mean, ffn_layer_norm_mean,
ffn_layer_norm_variance, ffn_layer_norm_variance,
ffn_layer_norm_out, ffn_layer_norm_out,
...@@ -3216,8 +3051,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ...@@ -3216,8 +3051,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
ffn_eltadd1_out, ffn_eltadd1_out,
ffn_gelu, ffn_gelu,
ffn_gelu_out, ffn_gelu_out,
ffn_dropout,
ffn_dropout_out,
ffn_eltadd_out}); ffn_eltadd_out});
// Remove unneeded nodes. // Remove unneeded nodes.
......
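The encoder-pass changes in this file follow the same theme as the decoder pass: the patterns no longer require a dropout op after softmax_qk, eltadd_linear or ffn_eltadd1 (their output vars now assert that they feed matmul_v2 or elementwise_add directly), the dropout nodes disappear from the GET_IR_NODE_FROM_SUBGRAPH lists and the marked_nodes sets, and shared parameters such as matmul_linear_w, eltadd_linear_b and the FFN weights/biases are now linked to the fused op instead of being deleted with the matched subgraph. Two fragments, copied from the new lines above (the fuse-QKV variant), capture the behavioural side of the change:

    // CacheKV is filled with the same dtype as the QKV weight, so fp16 models
    // get an fp16 cache instead of the previously hard-coded FP32
    fill_const_op_desc.SetAttr("dtype",
                               static_cast<int>(framework::TransToProtoVarType(
                                   qkv_w_tensor->dtype())));

    // with the dropout ops gone from the pattern, the fused op is emitted in
    // inference mode with a fixed zero dropout rate
    fused_multi_transformer_op_desc.SetAttr("is_test", true);
    fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);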
...@@ -82,8 +82,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -82,8 +82,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -100,8 +98,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -100,8 +98,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -131,8 +127,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -131,8 +127,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
...@@ -179,8 +173,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -179,8 +173,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -200,8 +192,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -200,8 +192,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -228,8 +218,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -228,8 +218,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
...@@ -280,8 +268,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern ...@@ -280,8 +268,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -303,8 +289,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern ...@@ -303,8 +289,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -335,8 +319,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern ...@@ -335,8 +319,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
......
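The pattern declarations above are consumed in the pass callback through GET_IR_NODE_FROM_SUBGRAPH. A minimal sketch of how the trimmed pattern is read after this change, assuming a pattern object named fused_multi_transformer_pattern and showing only the nodes touched by the dropout removal (the handler body is illustrative, not code from this diff):

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

// Sketch only: with dropout_qk/dropout_qk_out gone, the handler fetches
// softmax_qk_out and it feeds matmul_qkv directly; likewise eltadd_linear_out
// now feeds the residual elementwise_add without a dropout in between.
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                   Graph* graph) {
  GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, fused_multi_transformer_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, fused_multi_transformer_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, fused_multi_transformer_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(eltadd_linear_out, eltadd_linear_out, fused_multi_transformer_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fused_multi_transformer_pattern);
  // ... collect the remaining pattern nodes and emit the fused_multi_transformer op.
};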
...@@ -81,13 +81,11 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -81,13 +81,11 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (transpose_0, transpose_1) matmul -> matmul_qk // (transpose_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -96,8 +94,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -96,8 +94,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (ffn_eltadd0) gelu -> ffn_gelu // (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -149,10 +146,9 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -149,10 +146,9 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, transpose_2); auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -165,9 +161,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -165,9 +161,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -190,9 +184,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -190,9 +184,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -210,12 +202,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -210,12 +202,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before, PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 68, num_nodes_after + 56,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_pass, The " "After the fused_multi_transformer_encoder_pass, The "
"node num in graph " "node num in graph "
"should be %d, but the result is %d", "should be %d, but the result is %d",
num_nodes_before - 68, num_nodes_before - 56,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -246,13 +238,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -246,13 +238,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -261,8 +251,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -261,8 +251,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// (ffn_eltadd0) gelu -> ffn_gelu // (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -304,10 +293,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -304,10 +293,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v); auto* matmul_qkv = layers.matmul_v2(softmax_qk, split_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -320,9 +308,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -320,9 +308,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -345,9 +331,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -345,9 +331,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -366,11 +350,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -366,11 +350,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 56, num_nodes_after + 44,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, " "After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 56, num_nodes_before - 44,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -402,14 +386,12 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -402,14 +386,12 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_all_reduce -> c_all_reduce_out // (matmul_linear) c_all_reduce -> c_all_reduce_out
// (c_all_reduce_out) elementwise_add -> eltadd_linear // (c_all_reduce_out) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -420,8 +402,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -420,8 +402,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out // (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out
// (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1 // (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -464,10 +445,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -464,10 +445,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v); auto* matmul_qkv = layers.matmul_v2(softmax_qk, split_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -481,9 +461,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -481,9 +461,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -508,9 +486,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -508,9 +486,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -531,11 +507,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -531,11 +507,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 64, num_nodes_after + 52,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, " "After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 64, num_nodes_before - 52,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
......
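Each expected node-count delta in the testers above shrinks by 12 (68 to 56, 56 to 44, 64 to 52), reflecting the per-layer dropout ops and their output variables that the test graphs no longer create. The tests can drop them because simplify_with_basic_ops_pass is assumed to have already removed dropout from the inference graph before the fuse passes run (see the final hunk below). A minimal sketch of that removal, under the assumption that an "upscale_in_train" dropout is an identity mapping at inference; the helper name RemoveUpscaleDropout is hypothetical and Mask handling is omitted, the real logic lives in simplify_with_basic_ops_pass.cc:

#include <unordered_set>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Sketch only: rewire every consumer of the dropout Out var to read the
// dropout input directly, then drop the op and its Out var from the graph.
void RemoveUpscaleDropout(Graph* graph, Node* dropout_op) {
  Node* in = nullptr;
  Node* out = nullptr;
  for (auto* var : dropout_op->inputs) {
    if (var->Name() == dropout_op->Op()->Input("X")[0]) in = var;
  }
  for (auto* var : dropout_op->outputs) {
    if (var->Name() == dropout_op->Op()->Output("Out")[0]) out = var;
  }
  for (auto* consumer : out->outputs) {
    consumer->Op()->RenameInput(out->Name(), in->Name());
    IR_NODE_LINK_TO(in, consumer);
  }
  GraphSafeRemoveNodes(graph, {dropout_op, out});
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle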
...@@ -39,6 +39,7 @@ namespace ir { ...@@ -39,6 +39,7 @@ namespace ir {
static const char kParamScopeAttr[] = "__param_scope__"; static const char kParamScopeAttr[] = "__param_scope__";
static const std::vector<std::string> support_subgraph_passes = { static const std::vector<std::string> support_subgraph_passes = {
"simplify_with_basic_ops_pass",
"fused_multi_transformer_encoder_pass", "fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass", "fused_multi_transformer_decoder_pass",
"fused_multi_transformer_encoder_fuse_qkv_pass", "fused_multi_transformer_encoder_fuse_qkv_pass",
......
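The hunk above registers simplify_with_basic_ops_pass in support_subgraph_passes so that dropout removal also reaches sub-graphs such as the decoder's while block; otherwise the trimmed patterns would never match inside the loop body. A rough sketch of how such a list is typically consulted, assuming sub-graph index 0 is the main graph; the free function and its name are illustrative, the real dispatch sits inside Pass::Apply:

#include <algorithm>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Illustrative only: a pass listed in support_subgraph_passes is assumed to be
// applied to every sub-graph in addition to the main graph.
void ApplyMaybeOnSubGraphs(const paddle::framework::ir::Pass& pass,
                           paddle::framework::ir::Graph* graph) {
  pass.Apply(graph);  // main graph
  bool subgraph_capable =
      std::count(support_subgraph_passes.begin(),
                 support_subgraph_passes.end(), pass.Type()) > 0;
  if (!subgraph_capable) return;
  for (size_t i = 1; i < graph->SubGraphsSize(); ++i) {
    pass.Apply(graph->GetSubGraph(i));
  }
}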