未验证 提交 ba4fbe71 编写于 作者: Kaipeng Deng 提交者: GitHub

[cherry pick] fix memory copy in prepare_data of FusedMultiTransformer pass (#47308)

* fix memory copy in prepare_data. test=develop

* add cache_kv fp16 support. test=develop

* adapt fuse patterns to simplify_with_basic_ops_pass. test=develop
上级 7a1cf277
...@@ -88,8 +88,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase { ...@@ -88,8 +88,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -106,8 +104,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase { ...@@ -106,8 +104,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -137,8 +133,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase { ...@@ -137,8 +133,6 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
...@@ -193,8 +187,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { ...@@ -193,8 +187,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -211,8 +203,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { ...@@ -211,8 +203,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -239,8 +229,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase { ...@@ -239,8 +229,6 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
...@@ -299,8 +287,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern ...@@ -299,8 +287,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -319,8 +305,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern ...@@ -319,8 +305,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -351,8 +335,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern ...@@ -351,8 +335,6 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
......
...@@ -85,13 +85,11 @@ TEST(FusedMultiTransformerDecoderPass, basic) { ...@@ -85,13 +85,11 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
// (transpose_0, transpose_1) matmul -> matmul_qk // (transpose_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -100,8 +98,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) { ...@@ -100,8 +98,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
// (ffn_eltadd0) gelu -> ffn_gelu // (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
Layers layers; Layers layers;
// MHA: pre LayerNorm // MHA: pre LayerNorm
...@@ -154,10 +151,9 @@ TEST(FusedMultiTransformerDecoderPass, basic) { ...@@ -154,10 +151,9 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -170,9 +166,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) { ...@@ -170,9 +166,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -195,9 +189,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) { ...@@ -195,9 +189,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -215,12 +207,12 @@ TEST(FusedMultiTransformerDecoderPass, basic) { ...@@ -215,12 +207,12 @@ TEST(FusedMultiTransformerDecoderPass, basic) {
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before, PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 72, num_nodes_after + 60,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_pass, The " "After the fused_multi_transformer_decoder_pass, The "
"node num in graph " "node num in graph "
"should be %d, but the result is %d", "should be %d, but the result is %d",
num_nodes_before - 72, num_nodes_before - 60,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -253,13 +245,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -253,13 +245,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -268,8 +258,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -268,8 +258,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
// (ffn_eltadd0) gelu -> ffn_gelu // (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -313,10 +302,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -313,10 +302,9 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -329,9 +317,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -329,9 +317,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -354,9 +340,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -354,9 +340,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -375,11 +359,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -375,11 +359,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 62, num_nodes_after + 50,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, " "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 62, num_nodes_before - 50,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -413,14 +397,12 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_allreduce_sum -> c_all_reduce_out // (matmul_linear) c_allreduce_sum -> c_all_reduce_out
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -431,8 +413,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_allreduce_sum -> c_allreduce_out // (ffn_matmul1) c_allreduce_sum -> c_allreduce_out
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -477,10 +458,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, concat_v); auto* matmul_qkv = layers.matmul_v2(softmax_qk, concat_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -494,9 +474,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -521,9 +499,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_c_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { ...@@ -544,11 +520,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 70, num_nodes_after + 58,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_decoder_fuse_qkv_pass, " "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 70, num_nodes_before - 58,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
......
...@@ -82,8 +82,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -82,8 +82,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -100,8 +98,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -100,8 +98,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -131,8 +127,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase { ...@@ -131,8 +127,6 @@ struct FusedMultiTransformerEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
...@@ -179,8 +173,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -179,8 +173,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -200,8 +192,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -200,8 +192,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -228,8 +218,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase { ...@@ -228,8 +218,6 @@ struct FusedMultiTransformerEncoderFuseQKVPattern : public PatternBase {
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
...@@ -280,8 +268,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern ...@@ -280,8 +268,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(dropout_qk);
PATTERN_DECL_NODE(dropout_qk_out);
// QK, V matmul // QK, V matmul
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
...@@ -303,8 +289,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern ...@@ -303,8 +289,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(eltadd_linear); PATTERN_DECL_NODE(eltadd_linear);
PATTERN_DECL_NODE(eltadd_linear_b); PATTERN_DECL_NODE(eltadd_linear_b);
PATTERN_DECL_NODE(eltadd_linear_out); PATTERN_DECL_NODE(eltadd_linear_out);
PATTERN_DECL_NODE(dropout_linear);
PATTERN_DECL_NODE(dropout_linear_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(eltadd_out) PATTERN_DECL_NODE(eltadd_out)
...@@ -335,8 +319,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern ...@@ -335,8 +319,6 @@ struct MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern
PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD PATTERN_DECL_NODE(ffn_eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(ffn_eltadd1_out); PATTERN_DECL_NODE(ffn_eltadd1_out);
PATTERN_DECL_NODE(ffn_dropout);
PATTERN_DECL_NODE(ffn_dropout_out);
// output elementwise_add // output elementwise_add
PATTERN_DECL_NODE(ffn_eltadd_out) PATTERN_DECL_NODE(ffn_eltadd_out)
......
...@@ -81,13 +81,11 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -81,13 +81,11 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (transpose_0, transpose_1) matmul -> matmul_qk // (transpose_0, transpose_1) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -96,8 +94,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -96,8 +94,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
// (ffn_eltadd0) gelu -> ffn_gelu // (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -149,10 +146,9 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -149,10 +146,9 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk, nullptr, -1);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, transpose_2); auto* matmul_qkv = layers.matmul_v2(softmax_qk, transpose_2);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -165,9 +161,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -165,9 +161,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -190,9 +184,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -190,9 +184,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -210,12 +202,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) { ...@@ -210,12 +202,12 @@ TEST(FusedMultiTransformerEncoderPass, basic) {
int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); int num_fused_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(num_nodes_before, PADDLE_ENFORCE_EQ(num_nodes_before,
num_nodes_after + 68, num_nodes_after + 56,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_pass, The " "After the fused_multi_transformer_encoder_pass, The "
"node num in graph " "node num in graph "
"should be %d, but the result is %d", "should be %d, but the result is %d",
num_nodes_before - 68, num_nodes_before - 56,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -246,13 +238,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -246,13 +238,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) elementwise_add -> eltadd_linear // (matmul_linear) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -261,8 +251,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -261,8 +251,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
// (ffn_eltadd0) gelu -> ffn_gelu // (ffn_eltadd0) gelu -> ffn_gelu
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1 // (ffn_matmul1, ffn_bias1) elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -304,10 +293,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -304,10 +293,9 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v); auto* matmul_qkv = layers.matmul_v2(softmax_qk, split_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -320,9 +308,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -320,9 +308,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2); layers.elementwise_add(linear_matmut_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -345,9 +331,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -345,9 +331,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_matmul1_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -366,11 +350,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -366,11 +350,11 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 56, num_nodes_after + 44,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, " "After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 56, num_nodes_before - 44,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
...@@ -402,14 +386,12 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -402,14 +386,12 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// (split_q, split_k) matmul -> matmul_qk // (split_q, split_k) matmul -> matmul_qk
// (matmul_qk, bias_qk) elementwise_add -> eltadd_qk // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
// (eltadd_qk) softmax -> softmax_qk // (eltadd_qk) softmax -> softmax_qk
// (softmax_qk) dropout -> dropout_qk // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
// (dropout_qk, transpose_2) matmul_v2 -> matmul_qkv
// (matmul_qkv) transpose -> transpose_qkv // (matmul_qkv) transpose -> transpose_qkv
// (transpose_qkv) reshape -> reshape_qkv // (transpose_qkv) reshape -> reshape_qkv
// (reshape_qkv) matmul_v2 -> matmul_linear // (reshape_qkv) matmul_v2 -> matmul_linear
// (matmul_linear) c_all_reduce -> c_all_reduce_out // (matmul_linear) c_all_reduce -> c_all_reduce_out
// (c_all_reduce_out) elementwise_add -> eltadd_linear // (c_all_reduce_out) elementwise_add -> eltadd_linear
// (eltadd_linear) dropout -> dropout_linear
// (eltadd_out) elementwise_add -> attention_out // (eltadd_out) elementwise_add -> attention_out
// //
// (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out // (attention_out, scale, bias) layer_norm -> ffn_layer_norm_out
...@@ -420,8 +402,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -420,8 +402,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
// (ffn_gelu) matmul_v2 -> ffn_matmul1 // (ffn_gelu) matmul_v2 -> ffn_matmul1
// (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out // (ffn_matmul1) c_all_reduce -> ffn_c_all_reduce_out
// (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1 // (ffn_c_all_reduce_out, ffn_bias1)elementwise_add -> ffn_eltadd1
// (ffn_eltadd1) dropout -> ffn_dropout // (attention_out, ffn_eltadd1) elementwise_add -> ffn_output
// (attention_out, ffn_dropout) elementwise_add -> ffn_output
// //
// (transpose_1, transpose_2) while -> decoder block // (transpose_1, transpose_2) while -> decoder block
...@@ -464,10 +445,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -464,10 +445,9 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
auto* softmax_qk = layers.softmax(elementwise_qk, -1); auto* softmax_qk = layers.softmax(elementwise_qk, -1);
auto* dropout_qk = layers.dropout(softmax_qk, 0.1, "upscale_in_train");
// MHA: QKV matmul // MHA: QKV matmul
auto* matmul_qkv = layers.matmul_v2(dropout_qk, split_v); auto* matmul_qkv = layers.matmul_v2(softmax_qk, split_v);
auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true); auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 1024}, true);
...@@ -481,9 +461,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -481,9 +461,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* linear_eltadd_out = auto* linear_eltadd_out =
layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2); layers.elementwise_add(c_allreduce_out, bias_l, nullptr, 2);
auto* dropout_qkv = auto* attention_out = layers.elementwise_add(x, linear_eltadd_out);
layers.dropout(linear_eltadd_out, 0.1, "upscale_in_train");
auto* attention_out = layers.elementwise_add(x, dropout_qkv);
// FFN: pre LayerNorm // FFN: pre LayerNorm
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true);
...@@ -508,9 +486,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -508,9 +486,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
auto* ffn_eltadd1_out = auto* ffn_eltadd1_out =
layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2); layers.elementwise_add(ffn_allreduce_out, ffn_bias1, nullptr, 2);
// FFN: dropout -> elementwise_add layers.elementwise_add(attention_out, ffn_eltadd1_out);
auto* ffn_dropout = layers.dropout(ffn_eltadd1_out, 0.1, "upscale_in_train");
layers.elementwise_add(attention_out, ffn_dropout);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -531,11 +507,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { ...@@ -531,11 +507,11 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
num_nodes_before, num_nodes_before,
num_nodes_after + 64, num_nodes_after + 52,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After the fused_multi_transformer_encoder_fuse_qkv_pass, " "After the fused_multi_transformer_encoder_fuse_qkv_pass, "
"The node num in graph should be %d, but the result is %d", "The node num in graph should be %d, but the result is %d",
num_nodes_before - 64, num_nodes_before - 52,
num_nodes_after)); num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, PADDLE_ENFORCE_EQ(num_fused_nodes_after,
1, 1,
......
...@@ -39,6 +39,7 @@ namespace ir { ...@@ -39,6 +39,7 @@ namespace ir {
static const char kParamScopeAttr[] = "__param_scope__"; static const char kParamScopeAttr[] = "__param_scope__";
static const std::vector<std::string> support_subgraph_passes = { static const std::vector<std::string> support_subgraph_passes = {
"simplify_with_basic_ops_pass",
"fused_multi_transformer_encoder_pass", "fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass", "fused_multi_transformer_decoder_pass",
"fused_multi_transformer_encoder_fuse_qkv_pass", "fused_multi_transformer_encoder_fuse_qkv_pass",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册