diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 5d31b443c1b91a1478cebd670521f25cf0a7fd51..ef89dbb3ffe6e680473713376595a8959f52586f 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -383,6 +383,7 @@ set(IR_PASS_DEPS fix_op_run_order_pass fuse_gemm_epilogue_pass fused_attention_pass + fused_feedforward_pass delete_dropout_op_pass) if(WITH_CINN) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 47a262ea3576dc152ca0d2afe45289188f497f11..395da4b0a092861cdfc63422672c38898fea082d 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -210,6 +210,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPass("fuse_sgd_op_pass"); AppendPass("fuse_momentum_op_pass"); } +#ifdef PADDLE_WITH_CUDA + AppendPassWithCheck(strategy_.fused_feedforward_, "fused_feedforward_pass"); +#endif } void SetCollectiveContext() const { @@ -529,6 +532,9 @@ USE_PASS(fused_attention_pass); #ifdef PADDLE_WITH_CINN USE_PASS(build_cinn_pass); #endif +#ifdef PADDLE_WITH_CUDA +USE_PASS(fused_feedforward_pass); +#endif #ifdef PADDLE_WITH_MKLDNN USE_PASS(mkldnn_placement_pass); #endif diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 29e390bf0fd772cd46ecb63ee96be478cbdba4e5..1cd15746ddd3230a926a6f39728cb8765d2e368d 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -131,6 +131,8 @@ struct BuildStrategy { bool fuse_gemm_epilogue_{false}; // Fused multi head attention bool fused_attention_{false}; + // Fused feed forward + bool fused_feedforward_{false}; // mkldnn_enabled_op_types specify the operator type list to // use MKLDNN acceleration. 
It is null in default, means @@ -264,6 +266,7 @@ inline std::ostream &operator<<(std::ostream &os, os << "sync_batch_norm_: " << strategy.sync_batch_norm_ << std::endl; os << "fuse_gemm_epilogue_: " << strategy.fuse_gemm_epilogue_ << std::endl; os << "fused_attention_: " << strategy.fused_attention_ << std::endl; + os << "fused_feedforward_: " << strategy.fused_feedforward_ << std::endl; os << "mkldnn_enabled_op_types_: "; for (auto str : strategy.mkldnn_enabled_op_types_) { os << str << ", "; diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 1b329a0d65b84453eee478d4d7c98f8eeb9f0bb3..d9182c488f23f9839eb4b13e5e0c9a83ededc36b 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -126,6 +126,7 @@ message BuildStrategy { optional bool fuse_gemm_epilogue = 16 [ default = false ]; optional string debug_graphviz_path = 17; optional bool fused_attention = 18 [ default = false]; + optional bool fused_feedforward = 19 [ default = false]; } message ExecutionStrategy { diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 01627a33633cdeb21a9f86e4defbbf6c3a501e1a..b619bef90102793bd6152ca8fd50dce4fe49c6c9 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -264,6 +264,10 @@ cc_library( fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector) +cc_library( + fused_feedforward_pass + SRCS fused_feedforward_pass.cc + DEPS pass graph_pattern_detector) set(GLOB_PASS_LIB ${INFER_IR_PASSES} diff --git a/paddle/fluid/framework/ir/fused_feedforward_pass.cc b/paddle/fluid/framework/ir/fused_feedforward_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ae23adfc425c8e64d907f7432b386696db48a8f --- /dev/null +++ b/paddle/fluid/framework/ir/fused_feedforward_pass.cc @@ -0,0 +1,760 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fused_feedforward_pass.h" + +#include <string> +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void FusedFeedForwardPass::ApplyImpl(ir::Graph *graph) const { + FusePassBase::Init(scope_name, graph); + for (auto use_mp : std::vector<bool>({true, false})) { + for (auto pre_layer_norm : std::vector<bool>({true, false})) { + for (auto add_residual : std::vector<bool>({true, false})) { + for (auto use_dropout_1 : std::vector<bool>({true, false})) { + for (auto use_dropout_2 : std::vector<bool>({true, false})) { + // pre_layer_norm and add_residual can't both be false! + if (!pre_layer_norm && !add_residual) continue; + // use_dropout_1 and use_dropout_2 can't both be false!
+ if (!use_dropout_1 && !use_dropout_2) continue; + Cache dropout_nodes_map; + graph = FusedFeedForwardFwd(graph, + use_mp, + pre_layer_norm, + add_residual, + use_dropout_1, + use_dropout_2, + &dropout_nodes_map); + graph = FusedFeedForwardBwd(graph, + use_mp, + pre_layer_norm, + add_residual, + use_dropout_1, + use_dropout_2, + &dropout_nodes_map); + } + } + } + } + } +} + +ir::Graph *FusedFeedForwardPass::FusedFeedForwardFwd( + ir::Graph *graph, + bool use_mp, + bool pre_layer_norm, + bool add_residual, + bool use_dropout_1, + bool use_dropout_2, + Cache *dropout_nodes_map) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + const std::string scope_name("fused_feed_forward_fwd_pattern"); + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(scope_name, "x")) + ->AsInput(); + if (pre_layer_norm) { + x->assert_is_op_input("layer_norm", "X"); + } else { + x->assert_is_op_input("matmul_v2", "X"); + } + + // 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2 + // -> residual_add (pre_layer_norm) + // 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add + // -> layer_norm (post_layer_norm) + // other cases: may delete mp, residual_add, dropout1, dropout2 operators + patterns::FusedFeedForwardFwd fused_feedforward_pattern(gpd.mutable_pattern(), + scope_name); + std::unordered_set<std::string> act_types = {"gelu", "relu"}; + + VLOG(4) << "Fused Feedforward forward pass." + << " pre_layer_norm: " << pre_layer_norm + << ", add_residual: " << add_residual + << ", use_dropout_1: " << use_dropout_1 + << ", use_dropout_2: " << use_dropout_2; + + fused_feedforward_pattern(x, + act_types, + use_mp, + pre_layer_norm, + add_residual, + use_dropout_1, + use_dropout_2); + + int found_fused_feedforward_fwd_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle feed_forward forward fusion"; + + // LayerNorm + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_op, layer_norm_op, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias, layer_norm_bias, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_out, layer_norm_out, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_mean, layer_norm_mean, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_variance, layer_norm_variance, fused_feedforward_pattern); + + // Linear1 + GET_IR_NODE_FROM_SUBGRAPH( + matmul_op_1, matmul_op_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_w_1, matmul_w_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_out_1, matmul_out_1, fused_feedforward_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_op_1, ele_add_op_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_bias_1, ele_add_bias_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_out_1, ele_add_out_1, fused_feedforward_pattern); + + // Activation + GET_IR_NODE_FROM_SUBGRAPH(act_op, act_op, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_feedforward_pattern); + // Linear2 + GET_IR_NODE_FROM_SUBGRAPH( + matmul_op_2, matmul_op_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_w_2, matmul_w_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_out_2, matmul_out_2, fused_feedforward_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH( + ele_add_op_2, ele_add_op_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_bias_2, ele_add_bias_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_out_2, ele_add_out_2, fused_feedforward_pattern); + + if (use_dropout_1 && use_dropout_2) { + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_1, dropout_op_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_2, dropout_op_2, fused_feedforward_pattern); + if (PADDLE_GET_CONST(bool, dropout_op_1->Op()->GetAttr("is_test")) != + PADDLE_GET_CONST(bool, dropout_op_2->Op()->GetAttr("is_test"))) { + LOG(WARNING) << "Dropout 1 and dropout 2 attribute is_test set " + "different values. " + << "Skip fused_feedforward pattern replacement."; + return; + } + } + + OpDesc fused_feedforward_op_desc(layer_norm_op->Op()->Block()); + + fused_feedforward_op_desc.SetType("fused_feedforward"); + fused_feedforward_op_desc.SetInput("X", {subgraph.at(x)->Name()}); + fused_feedforward_op_desc.SetInput("Linear1Weight", {matmul_w_1->Name()}); + fused_feedforward_op_desc.SetInput("Linear1Bias", {ele_add_bias_1->Name()}); + fused_feedforward_op_desc.SetInput("Linear2Weight", {matmul_w_2->Name()}); + fused_feedforward_op_desc.SetInput("Linear2Bias", {ele_add_bias_2->Name()}); + if (pre_layer_norm) { + fused_feedforward_op_desc.SetInput("Ln1Scale", + {layer_norm_scale->Name()}); + fused_feedforward_op_desc.SetInput("Ln1Bias", {layer_norm_bias->Name()}); + fused_feedforward_op_desc.SetOutput("Ln1Mean", {layer_norm_mean->Name()}); + fused_feedforward_op_desc.SetOutput("Ln1Variance", + {layer_norm_variance->Name()}); + fused_feedforward_op_desc.SetOutput("Ln1Out", {layer_norm_out->Name()}); + fused_feedforward_op_desc.SetAttr( + "ln1_epsilon", layer_norm_op->Op()->GetAttr("epsilon")); + if (!add_residual) { + if (use_dropout_2) { + GET_IR_NODE_FROM_SUBGRAPH( + dropout_out_2, dropout_out_2, fused_feedforward_pattern); + fused_feedforward_op_desc.SetOutput("Out", {dropout_out_2->Name()}); + } else { + fused_feedforward_op_desc.SetOutput("Out", {ele_add_out_2->Name()}); + } + } else { + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_out_3, ele_add_out_3, fused_feedforward_pattern); + fused_feedforward_op_desc.SetOutput("Out", {ele_add_out_3->Name()}); + } + } else { + fused_feedforward_op_desc.SetInput("Ln2Scale", + {layer_norm_scale->Name()}); + fused_feedforward_op_desc.SetInput("Ln2Bias", {layer_norm_bias->Name()}); + fused_feedforward_op_desc.SetOutput("Ln2Mean", {layer_norm_mean->Name()}); + fused_feedforward_op_desc.SetOutput("Ln2Variance", + {layer_norm_variance->Name()}); + fused_feedforward_op_desc.SetAttr( + "ln2_epsilon", layer_norm_op->Op()->GetAttr("epsilon")); + fused_feedforward_op_desc.SetOutput("Out", {layer_norm_out->Name()}); + } + + bool is_test = false; + DropoutNode record; + if (use_dropout_1) { + // Dropout1 + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_1, dropout_op_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_mask_1, dropout_mask_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_out_1, dropout_out_1, fused_feedforward_pattern); + record.dropout_mask_node_1 = dropout_mask_1; + record.dropout_out_node_1 = dropout_out_1; + fused_feedforward_op_desc.SetOutput("Dropout1Mask", + {dropout_mask_1->Name()}); + fused_feedforward_op_desc.SetOutput("Dropout1Out", + {dropout_out_1->Name()}); + fused_feedforward_op_desc.SetAttr( + "dropout1_rate", dropout_op_1->Op()->GetAttr("dropout_prob")); + fused_feedforward_op_desc.SetAttr( + "dropout1_implementation", + 
dropout_op_1->Op()->GetAttr("dropout_implementation")); + is_test = PADDLE_GET_CONST(bool, dropout_op_1->Op()->GetAttr("is_test")); + } else { + fused_feedforward_op_desc.SetAttr("dropout1_rate", 0.0f); + VarDesc dropout_out_desc_1( + patterns::PDNodeName(scope_name, "dropout_out_1")); + dropout_out_desc_1.SetShape(ele_add_out_1->Var()->GetShape()); + dropout_out_desc_1.SetDataType(ele_add_out_1->Var()->GetDataType()); + dropout_out_desc_1.SetLoDLevel(ele_add_out_1->Var()->GetLoDLevel()); + dropout_out_desc_1.SetStopGradient(static_cast<bool>(true)); + record.dropout_out_node_1 = g->CreateVarNode(&dropout_out_desc_1); + fused_feedforward_op_desc.SetOutput("Dropout1Out", + {record.dropout_out_node_1->Name()}); + + VarDesc dropout_mask_desc_1( + patterns::PDNodeName(scope_name, "dropout_mask_1")); + dropout_mask_desc_1.SetShape(ele_add_out_1->Var()->GetShape()); + dropout_mask_desc_1.SetDataType(proto::VarType::UINT8); + dropout_mask_desc_1.SetLoDLevel(ele_add_out_1->Var()->GetLoDLevel()); + dropout_mask_desc_1.SetStopGradient(static_cast<bool>(true)); + // Transfer to backward operator. + record.dropout_mask_node_1 = g->CreateVarNode(&dropout_mask_desc_1); + fused_feedforward_op_desc.SetOutput("Dropout1Mask", + {record.dropout_mask_node_1->Name()}); + } + + if (use_dropout_2) { + // Dropout2 + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_2, dropout_op_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_mask_2, dropout_mask_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + dropout_out_2, dropout_out_2, fused_feedforward_pattern); + record.dropout_out_node_2 = dropout_out_2; + record.dropout_mask_node_2 = dropout_mask_2; + fused_feedforward_op_desc.SetOutput("Dropout2Out", + {dropout_out_2->Name()}); + fused_feedforward_op_desc.SetOutput("Dropout2Mask", + {dropout_mask_2->Name()}); + fused_feedforward_op_desc.SetAttr( + "dropout2_rate", dropout_op_2->Op()->GetAttr("dropout_prob")); + fused_feedforward_op_desc.SetAttr( + "dropout2_implementation", + dropout_op_2->Op()->GetAttr("dropout_implementation")); + is_test = PADDLE_GET_CONST(bool, dropout_op_2->Op()->GetAttr("is_test")); + } else { + fused_feedforward_op_desc.SetAttr("dropout2_rate", 0.0f); + VarDesc dropout_out_desc_2( + patterns::PDNodeName(scope_name, "dropout_out_2")); + dropout_out_desc_2.SetShape(ele_add_out_2->Var()->GetShape()); + dropout_out_desc_2.SetDataType(ele_add_out_2->Var()->GetDataType()); + dropout_out_desc_2.SetLoDLevel(ele_add_out_2->Var()->GetLoDLevel()); + dropout_out_desc_2.SetStopGradient(static_cast<bool>(true)); + record.dropout_out_node_2 = g->CreateVarNode(&dropout_out_desc_2); + fused_feedforward_op_desc.SetOutput("Dropout2Out", + {record.dropout_out_node_2->Name()}); + + VarDesc dropout_mask_desc_2( + patterns::PDNodeName(scope_name, "dropout_mask_2")); + dropout_mask_desc_2.SetShape(ele_add_out_2->Var()->GetShape()); + dropout_mask_desc_2.SetDataType(proto::VarType::UINT8); + dropout_mask_desc_2.SetLoDLevel(ele_add_out_2->Var()->GetLoDLevel()); + dropout_mask_desc_2.SetStopGradient(static_cast<bool>(true)); + // Transmit to backward operator. + record.dropout_mask_node_2 = g->CreateVarNode(&dropout_mask_desc_2); + fused_feedforward_op_desc.SetOutput("Dropout2Mask", + {record.dropout_mask_node_2->Name()}); + } + // Transmit to backward operator. 
+ dropout_nodes_map->insert(std::make_pair(matmul_w_1, record)); + + fused_feedforward_op_desc.SetOutput("Linear1Out", {ele_add_out_1->Name()}); + fused_feedforward_op_desc.SetAttr("pre_layer_norm", pre_layer_norm); + fused_feedforward_op_desc.SetAttr("act_method", act_op->Op()->Type()); + + if (!use_dropout_1 && !use_dropout_2) { + is_test = true; + } + fused_feedforward_op_desc.SetAttr("is_test", is_test); + // These attributes set default value + fused_feedforward_op_desc.SetAttr("dropout1_fix_seed", false); + fused_feedforward_op_desc.SetAttr("dropout2_fix_seed", false); + fused_feedforward_op_desc.SetAttr("dropout1_seed", 0); + fused_feedforward_op_desc.SetAttr("dropout2_seed", 0); + fused_feedforward_op_desc.SetAttr("add_residual", add_residual); + int ring_id = -1; + if (use_mp) { + GET_IR_NODE_FROM_SUBGRAPH( + c_allreduce_sum_op, c_allreduce_sum_op, fused_feedforward_pattern); + ring_id = + PADDLE_GET_CONST(int, c_allreduce_sum_op->Op()->GetAttr("ring_id")); + } + fused_feedforward_op_desc.SetAttr("ring_id", ring_id); + + auto fused_feedforward_node = g->CreateOpNode(&fused_feedforward_op_desc); + + IR_NODE_LINK_TO(subgraph.at(x), fused_feedforward_node); + IR_NODE_LINK_TO(matmul_w_1, fused_feedforward_node); + IR_NODE_LINK_TO(ele_add_bias_1, fused_feedforward_node); + IR_NODE_LINK_TO(matmul_w_2, fused_feedforward_node); + IR_NODE_LINK_TO(ele_add_bias_2, fused_feedforward_node); + IR_NODE_LINK_TO(layer_norm_scale, fused_feedforward_node); + IR_NODE_LINK_TO(layer_norm_bias, fused_feedforward_node); + IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_out); + IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_mean); + IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_variance); + IR_NODE_LINK_TO(fused_feedforward_node, record.dropout_mask_node_1); + IR_NODE_LINK_TO(fused_feedforward_node, record.dropout_out_node_1); + IR_NODE_LINK_TO(fused_feedforward_node, record.dropout_mask_node_2); + IR_NODE_LINK_TO(fused_feedforward_node, record.dropout_out_node_2); + IR_NODE_LINK_TO(fused_feedforward_node, ele_add_out_1); + if (!pre_layer_norm) { + IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_out); + } else { + if (add_residual) { + // Residual Add, dispensable + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_out_3, ele_add_out_3, fused_feedforward_pattern); + IR_NODE_LINK_TO(fused_feedforward_node, ele_add_out_3); + } else { + if (!use_dropout_2) { + IR_NODE_LINK_TO(fused_feedforward_node, ele_add_out_2); + } + } + } + + std::unordered_set<const Node *> nodes_to_remove = {layer_norm_op, + matmul_op_1, + ele_add_op_1, + act_op, + matmul_op_2, + ele_add_op_2}; + if (use_mp) { + GET_IR_NODE_FROM_SUBGRAPH( + c_identity_op, c_identity_op, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + c_allreduce_sum_op, c_allreduce_sum_op, fused_feedforward_pattern); + nodes_to_remove.insert(c_identity_op); + nodes_to_remove.insert(c_allreduce_sum_op); + } + if (use_dropout_1) { + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_1, dropout_op_1, fused_feedforward_pattern); + nodes_to_remove.insert(dropout_op_1); + } + if (use_dropout_2) { + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_2, dropout_op_2, fused_feedforward_pattern); + nodes_to_remove.insert(dropout_op_2); + } + if (add_residual) { + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_op_3, ele_add_op_3, fused_feedforward_pattern); + nodes_to_remove.insert(ele_add_op_3); + } + GraphSafeRemoveNodes(g, nodes_to_remove); + found_fused_feedforward_fwd_count++; + VLOG(4) << "After remove nodes."; + }; + + gpd(graph, handler); + AddStatis(found_fused_feedforward_fwd_count); + return graph; 
+} + +ir::Graph *FusedFeedForwardPass::FusedFeedForwardBwd( + ir::Graph *graph, + bool use_mp, + bool pre_layer_norm, + bool add_residual, + bool use_dropout_1, + bool use_dropout_2, + Cache *dropout_nodes_map) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + const std::string scope_name("fused_feed_forward_bwd_pattern"); + + // 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad -> + // activation_grad -> linear1_grad -> layer_norm_grad + // 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad -> + // dropout1_grad -> activation_grad -> linear1_grad + // other cases: may delete mp, residual_add_grad, dropout1_grad, dropout2_grad + // operators + GraphPatternDetector gpd; + + auto *x_grad = gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(scope_name, "x_grad")) + ->AsInput(); + + patterns::FusedFeedForwardBwd fused_feedforward_pattern(gpd.mutable_pattern(), + scope_name); + std::unordered_set<std::string> act_grad_types = {"gelu_grad", "relu_grad"}; + fused_feedforward_pattern(x_grad, + act_grad_types, + use_mp, + pre_layer_norm, + add_residual, + use_dropout_1, + use_dropout_2); + + VLOG(4) << "Fused Feedforward backward pass." + << " pre_layer_norm: " << pre_layer_norm + << ", add_residual: " << add_residual + << ", use_dropout_1: " << use_dropout_1 + << ", use_dropout_2: " << use_dropout_2; + + int found_fused_feedforward_bwd_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle feed_forward backward fusion"; + + // LayerNorm Grad + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_op_grad, layer_norm_op_grad, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_in, layer_norm_in, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_mean, layer_norm_mean, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_variance, layer_norm_variance, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias, layer_norm_bias, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_in_grad, layer_norm_in_grad, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale_grad, + layer_norm_scale_grad, + fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias_grad, layer_norm_bias_grad, fused_feedforward_pattern); + // Linear Grad 1 + GET_IR_NODE_FROM_SUBGRAPH( + matmul_op_grad_1, matmul_op_grad_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_in_1, matmul_in_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_w_1, matmul_w_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_in_grad_1, matmul_in_grad_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_w_grad_1, matmul_w_grad_1, fused_feedforward_pattern); + + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_op_grad_1, ele_add_op_grad_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_in_1, ele_add_in_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_bias_1, ele_add_bias_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_in_grad_1, ele_add_in_grad_1, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_bias_grad_1, ele_add_bias_grad_1, fused_feedforward_pattern); + // Activation Grad + GET_IR_NODE_FROM_SUBGRAPH( + act_op_grad, act_op_grad, fused_feedforward_pattern); + 
GET_IR_NODE_FROM_SUBGRAPH(act_in, act_in, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + act_in_grad, act_in_grad, fused_feedforward_pattern); + // Linear Grad 2 + GET_IR_NODE_FROM_SUBGRAPH( + matmul_op_grad_2, matmul_op_grad_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_in_2, matmul_in_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_w_2, matmul_w_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_in_grad_2, matmul_in_grad_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + matmul_w_grad_2, matmul_w_grad_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_op_grad_2, ele_add_op_grad_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_in_2, ele_add_in_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_bias_2, ele_add_bias_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_in_grad_2, ele_add_in_grad_2, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_bias_grad_2, ele_add_bias_grad_2, fused_feedforward_pattern); + + auto record = (*dropout_nodes_map)[matmul_w_1]; + if (use_dropout_1) { + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_grad_1, dropout_op_grad_1, fused_feedforward_pattern); + if (PADDLE_GET_CONST(bool, dropout_op_grad_1->Op()->GetAttr("is_test"))) { + LOG(WARNING) << "Dropout_grad 1 attribute is_test should be set false." + << " Skip fused_feedforward_grad pattern replacement"; + return; + } + } else { + if (record.dropout_mask_node_1 == nullptr || + record.dropout_out_node_1 == nullptr) { + LOG(WARNING) + << "Dropout_grad 1 has no mask/out input from forward pass." + << " Skip fused_feedforward_grad pattern replacement"; + return; + } + } + + if (use_dropout_2) { + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_grad_2, dropout_op_grad_2, fused_feedforward_pattern); + if (PADDLE_GET_CONST(bool, dropout_op_grad_2->Op()->GetAttr("is_test"))) { + LOG(WARNING) << "Dropout_grad 2 attribute is_test should be set false." + << " Skip fused_feedforward_grad pattern replacement"; + return; + } + } else { + if (record.dropout_mask_node_2 == nullptr) { + LOG(WARNING) << "Dropout_grad 2 has no mask input from forward pass." + << " Skip fused_feedforward_grad pattern replacement"; + return; + } + } + + OpDesc fused_feedforward_op_desc(layer_norm_op_grad->Op()->Block()); + + fused_feedforward_op_desc.SetType("fused_feedforward_grad"); + fused_feedforward_op_desc.SetInput(framework::GradVarName("Out"), + {subgraph.at(x_grad)->Name()}); + fused_feedforward_op_desc.SetInput( + "X", {pre_layer_norm ? 
layer_norm_in->Name() : matmul_in_1->Name()}); + fused_feedforward_op_desc.SetInput("Linear1Weight", {matmul_w_1->Name()}); + fused_feedforward_op_desc.SetInput("Linear1Bias", {ele_add_bias_1->Name()}); + fused_feedforward_op_desc.SetInput("Linear2Weight", {matmul_w_2->Name()}); + fused_feedforward_op_desc.SetInput("Linear2Bias", {ele_add_bias_2->Name()}); + fused_feedforward_op_desc.SetInput("Linear1Out", {act_in->Name()}); + fused_feedforward_op_desc.SetInput("Dropout1Out", + {record.dropout_out_node_1->Name()}); + fused_feedforward_op_desc.SetInput("Dropout1Mask", + {record.dropout_mask_node_1->Name()}); + fused_feedforward_op_desc.SetInput("Dropout2Mask", + {record.dropout_mask_node_2->Name()}); + + fused_feedforward_op_desc.SetOutput(GradVarName("Linear1Weight"), + {matmul_w_grad_1->Name()}); + fused_feedforward_op_desc.SetOutput(GradVarName("Linear1Bias"), + {ele_add_bias_grad_1->Name()}); + fused_feedforward_op_desc.SetOutput(GradVarName("Linear2Weight"), + {matmul_w_grad_2->Name()}); + fused_feedforward_op_desc.SetOutput(GradVarName("Linear2Bias"), + {ele_add_bias_grad_2->Name()}); + + fused_feedforward_op_desc.SetAttr("pre_layer_norm", pre_layer_norm); + fused_feedforward_op_desc.SetAttr( + "ln1_epsilon", layer_norm_op_grad->Op()->GetAttr("epsilon")); + fused_feedforward_op_desc.SetAttr( + "ln2_epsilon", layer_norm_op_grad->Op()->GetAttr("epsilon")); + fused_feedforward_op_desc.SetAttr("act_method", + act_op_grad->Op()->Type().substr(0, 4)); + fused_feedforward_op_desc.SetAttr("add_residual", add_residual); + // These attributes set default value + fused_feedforward_op_desc.SetAttr("is_test", false); + fused_feedforward_op_desc.SetAttr("dropout1_fix_seed", false); + fused_feedforward_op_desc.SetAttr("dropout2_fix_seed", false); + fused_feedforward_op_desc.SetAttr("dropout1_seed", 0); + fused_feedforward_op_desc.SetAttr("dropout2_seed", 0); + int ring_id = -1; + if (use_mp) { + GET_IR_NODE_FROM_SUBGRAPH( + c_allreduce_sum_op, c_allreduce_sum_op, fused_feedforward_pattern); + ring_id = + PADDLE_GET_CONST(int, c_allreduce_sum_op->Op()->GetAttr("ring_id")); + } + fused_feedforward_op_desc.SetAttr("ring_id", ring_id); + + if (pre_layer_norm) { + fused_feedforward_op_desc.SetInput("Ln1Scale", + {layer_norm_scale->Name()}); + fused_feedforward_op_desc.SetInput("Ln1Bias", {layer_norm_bias->Name()}); + fused_feedforward_op_desc.SetInput("Ln1Out", {matmul_in_1->Name()}); + fused_feedforward_op_desc.SetInput("Ln1Mean", {layer_norm_mean->Name()}); + fused_feedforward_op_desc.SetInput("Ln1Variance", + {layer_norm_variance->Name()}); + fused_feedforward_op_desc.SetOutput(GradVarName("Ln1Scale"), + {layer_norm_scale_grad->Name()}); + fused_feedforward_op_desc.SetOutput(GradVarName("Ln1Bias"), + {layer_norm_bias_grad->Name()}); + } else { + fused_feedforward_op_desc.SetInput("Ln2Scale", + {layer_norm_scale->Name()}); + fused_feedforward_op_desc.SetInput("Ln2Bias", {layer_norm_bias->Name()}); + fused_feedforward_op_desc.SetInput("Ln2Mean", {layer_norm_mean->Name()}); + fused_feedforward_op_desc.SetInput("Ln2Variance", + {layer_norm_variance->Name()}); + // Special + fused_feedforward_op_desc.SetInput("Dropout2Out", + {record.dropout_out_node_2->Name()}); + fused_feedforward_op_desc.SetOutput(GradVarName("Ln2Scale"), + {layer_norm_scale_grad->Name()}); + fused_feedforward_op_desc.SetOutput(GradVarName("Ln2Bias"), + {layer_norm_bias_grad->Name()}); + } + + if (use_dropout_1) { + // Dropout Grad 1 + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_grad_1, dropout_op_grad_1, fused_feedforward_pattern); + 
fused_feedforward_op_desc.SetAttr( + "dropout1_rate", dropout_op_grad_1->Op()->GetAttr("dropout_prob")); + fused_feedforward_op_desc.SetAttr( + "dropout1_implementation", + dropout_op_grad_1->Op()->GetAttr("dropout_implementation")); + } else { + fused_feedforward_op_desc.SetAttr("dropout1_rate", 0.0f); + fused_feedforward_op_desc.SetAttr( + "dropout1_implementation", + static_cast<std::string>("upscale_in_train")); + } + + if (use_dropout_2) { + // Dropout Grad 2 + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_grad_2, dropout_op_grad_2, fused_feedforward_pattern); + fused_feedforward_op_desc.SetAttr( + "dropout2_rate", dropout_op_grad_2->Op()->GetAttr("dropout_prob")); + fused_feedforward_op_desc.SetAttr( + "dropout2_implementation", + dropout_op_grad_2->Op()->GetAttr("dropout_implementation")); + } else { + fused_feedforward_op_desc.SetAttr("dropout2_rate", 0.0f); + fused_feedforward_op_desc.SetAttr( + "dropout2_implementation", + static_cast<std::string>("upscale_in_train")); + } + + if (add_residual) { + GET_IR_NODE_FROM_SUBGRAPH(sum_out, sum_out, fused_feedforward_pattern); + fused_feedforward_op_desc.SetOutput(GradVarName("X"), {sum_out->Name()}); + } else { + if (pre_layer_norm) { + fused_feedforward_op_desc.SetOutput(GradVarName("X"), + {layer_norm_in_grad->Name()}); + } else { + fused_feedforward_op_desc.SetOutput(GradVarName("X"), + {matmul_in_grad_1->Name()}); + } + } + + auto fused_feedforward_node = g->CreateOpNode(&fused_feedforward_op_desc); + IR_NODE_LINK_TO(subgraph.at(x_grad), fused_feedforward_node); + IR_NODE_LINK_TO(matmul_w_1, fused_feedforward_node); + IR_NODE_LINK_TO(ele_add_bias_1, fused_feedforward_node); + IR_NODE_LINK_TO(matmul_w_2, fused_feedforward_node); + IR_NODE_LINK_TO(ele_add_bias_2, fused_feedforward_node); + IR_NODE_LINK_TO(record.dropout_mask_node_1, fused_feedforward_node); + IR_NODE_LINK_TO(record.dropout_mask_node_2, fused_feedforward_node); + IR_NODE_LINK_TO(act_in, fused_feedforward_node); + IR_NODE_LINK_TO(record.dropout_out_node_1, fused_feedforward_node); + IR_NODE_LINK_TO(layer_norm_scale, fused_feedforward_node); + IR_NODE_LINK_TO(layer_norm_bias, fused_feedforward_node); + IR_NODE_LINK_TO(layer_norm_mean, fused_feedforward_node); + IR_NODE_LINK_TO(layer_norm_variance, fused_feedforward_node); + IR_NODE_LINK_TO(layer_norm_in, fused_feedforward_node); + if (pre_layer_norm) { + IR_NODE_LINK_TO(matmul_in_1, fused_feedforward_node); + } else { + IR_NODE_LINK_TO(record.dropout_out_node_2, fused_feedforward_node); + } + + IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_scale_grad); + IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_bias_grad); + IR_NODE_LINK_TO(fused_feedforward_node, matmul_w_grad_1); + IR_NODE_LINK_TO(fused_feedforward_node, ele_add_bias_grad_1); + IR_NODE_LINK_TO(fused_feedforward_node, matmul_w_grad_2); + IR_NODE_LINK_TO(fused_feedforward_node, ele_add_bias_grad_2); + + if (add_residual) { + GET_IR_NODE_FROM_SUBGRAPH(sum_out, sum_out, fused_feedforward_pattern); + IR_NODE_LINK_TO(fused_feedforward_node, sum_out); + } else { + if (pre_layer_norm) { + IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_in_grad); + } else { + IR_NODE_LINK_TO(fused_feedforward_node, matmul_in_grad_1); + } + } + + std::unordered_set<const Node *> nodes_to_remove = {layer_norm_op_grad, + matmul_op_grad_1, + ele_add_op_grad_1, + act_op_grad, + matmul_op_grad_2, + ele_add_op_grad_2}; + if (use_mp) { + GET_IR_NODE_FROM_SUBGRAPH( + c_identity_op, c_identity_op, fused_feedforward_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + c_allreduce_sum_op, c_allreduce_sum_op, fused_feedforward_pattern); + 
nodes_to_remove.insert(c_identity_op); + nodes_to_remove.insert(c_allreduce_sum_op); + } + if (use_dropout_1) { + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_grad_1, dropout_op_grad_1, fused_feedforward_pattern); + nodes_to_remove.insert(dropout_op_grad_1); + } + if (use_dropout_2) { + GET_IR_NODE_FROM_SUBGRAPH( + dropout_op_grad_2, dropout_op_grad_2, fused_feedforward_pattern); + nodes_to_remove.insert(dropout_op_grad_2); + } + if (add_residual) { + GET_IR_NODE_FROM_SUBGRAPH( + ele_add_op_grad_3, ele_add_op_grad_3, fused_feedforward_pattern); + // Sum for gradient addition + GET_IR_NODE_FROM_SUBGRAPH(sum_op, sum_op, fused_feedforward_pattern); + nodes_to_remove.insert(ele_add_op_grad_3); + nodes_to_remove.insert(sum_op); + } + GraphSafeRemoveNodes(g, nodes_to_remove); + found_fused_feedforward_bwd_count++; + }; + + gpd(graph, handler); + AddStatis(found_fused_feedforward_bwd_count); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fused_feedforward_pass, + paddle::framework::ir::FusedFeedForwardPass); diff --git a/paddle/fluid/framework/ir/fused_feedforward_pass.h b/paddle/fluid/framework/ir/fused_feedforward_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..6e048e8bb61eeeeff9b427898a6dd32a1c9c08d1 --- /dev/null +++ b/paddle/fluid/framework/ir/fused_feedforward_pass.h @@ -0,0 +1,92 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <memory> +#include <string> +#include <unordered_map> + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the FeedForward (FFN) block in Transformer layers. + * Forward: + * 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2 + * -> residual_add (pre_layer_norm) + * 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add + * -> layer_norm (post_layer_norm) + * other cases: may delete mp, residual_add, dropout1, dropout2 operators + * Backward: + * 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad -> + * activation_grad -> linear1_grad -> layer_norm_grad (pre_layer_norm) + * 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad -> + * dropout1_grad -> activation_grad -> linear1_grad (post_layer_norm) + * other cases: may delete mp, residual_add_grad, dropout1_grad, dropout2_grad + * operators + */ +class Graph; +class Node; + +class FusedFeedForwardPass : public FusePassBase { + public: + virtual ~FusedFeedForwardPass() {} + + protected: + // Used to transfer the variable nodes created for the pattern + // between the corresponding forward operator and backward operator. 
+ struct DropoutNode { + Node *dropout_out_node_1; + Node *dropout_mask_node_1; + Node *dropout_out_node_2; + Node *dropout_mask_node_2; + DropoutNode() + : dropout_out_node_1(nullptr), + dropout_mask_node_1(nullptr), + dropout_out_node_2(nullptr), + dropout_mask_node_2(nullptr) {} + }; + typedef std::unordered_map<Node *, DropoutNode> Cache; + + const std::string scope_name{"fused_feedforward"}; + + void ApplyImpl(ir::Graph *graph) const override; + + ir::Graph *FusedFeedForwardFwd(ir::Graph *graph, + bool use_mp, + bool pre_layer_norm, + bool add_residual, + bool use_dropout_1, + bool use_dropout_2, + Cache *dropout_nodes_map) const; + + ir::Graph *FusedFeedForwardBwd(ir::Graph *graph, + bool use_mp, + bool pre_layer_norm, + bool add_residual, + bool use_dropout_1, + bool use_dropout_2, + Cache *dropout_nodes_map) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 21a0ed860bc61e4a34dc5e29e30e801d9c86088d..ffcf7f78c27e608004e92f4e64799413d06f15ac 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -113,7 +113,8 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) { if (node.Name().rfind("__control_var") == 0) continue; for (const auto &pdnode : pattern_.nodes()) { if (pdnode->Tell(&node)) { - VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name(); + VLOG(4) << "Node " << node.Name() << "(" << node.id() << ")" + << " marked as " << pdnode->name(); pdnodes2nodes_[pdnode.get()].insert(&node); } } @@ -231,7 +232,8 @@ GraphPatternDetector::DetectPatterns() { // source -> target for (Node *source : pdnodes2nodes_[edge.first]) { for (Node *target : pdnodes2nodes_[edge.second]) { - VLOG(8) << "check " << source->id() << " -- " << target->id(); + VLOG(8) << "check " << source->Name() << "(" << source->id() << ")" + << " -- " << target->Name() << "(" << target->id() << ")"; // TODO(Superjomn) add some prune strategies. for (const auto &group : pre_groups) { if (IsNodesLink(source, target)) { @@ -251,7 +253,9 @@ GraphPatternDetector::DetectPatterns() { VLOG(3) << "step " << step << " get records: " << cur_groups.size(); for (auto &group : cur_groups) { for (auto &item : group.roles) { - VLOG(4) << "node " << item.second->id() << " as " << item.first->name(); + VLOG(4) << "node " << item.second->Name() << "(" << item.second->id() + << ")" + << " as " << item.first->name(); } VLOG(4) << "========================================================="; } @@ -4011,6 +4015,443 @@ PDNode *patterns::MergeLayernormPattern::operator()(PDNode *in) { return layernorm_40_out; } +PDNode *patterns::FusedFeedForwardFwd::operator()( + paddle::framework::ir::PDNode *x_var, + std::unordered_set<std::string> act_types, + bool use_mp, + bool pre_layer_norm, + bool add_residual, + bool use_dropout_1, + bool use_dropout_2) { + // Possible patterns + // 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2 + // -> residual_add (pre_layer_norm) + // 2. 
linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add + // -> layer_norm (pOST_layer_norm) + // other cases: may delete residual_add, dropout1, dropout2 operators + + // intermediate input, and final pattern output + PDNode *out_var = x_var; + // LayerNorm + auto *layer_norm_op = + pattern->NewNode(layer_norm_op_repr())->assert_is_op("layer_norm"); + auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_op_input("layer_norm", "Bias"); + auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_op_input("layer_norm", "Scale"); + auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Y"); + auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Mean"); + auto *layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->AsOutput() + ->assert_is_op_output("layer_norm", "Variance"); + if (pre_layer_norm) { + out_var->assert_is_op_input("layer_norm", "X"); + layer_norm_op + ->LinksFrom({out_var, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); + out_var = layer_norm_out_var; + } + + // Model parallel, do nothing in forward. + if (use_mp) { + out_var->assert_is_op_input("c_identity", "X"); + auto *c_identity_op = + pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity"); + auto *c_identity_out_var = pattern->NewNode(c_identity_out_repr()) + ->assert_is_op_output("c_identity", "Out"); + c_identity_op->LinksFrom({out_var}).LinksTo({c_identity_out_var}); + out_var = c_identity_out_var; + } + + // Linear1 + out_var->assert_is_op_input("matmul_v2", "X"); + auto *matmul_op_1 = + pattern->NewNode(matmul_op_1_repr())->assert_is_op("matmul_v2"); + auto *matmul_w_var_1 = pattern->NewNode(matmul_w_1_repr()) + ->AsInput() + ->assert_is_op_input("matmul_v2", "Y"); + auto *matmul_out_var_1 = pattern->NewNode(matmul_out_1_repr()) + ->assert_is_op_output("matmul_v2", "Out"); + matmul_op_1->LinksFrom({out_var, matmul_w_var_1}).LinksTo({matmul_out_var_1}); + out_var = matmul_out_var_1; + + out_var->assert_is_op_input("elementwise_add", "X"); + auto *ele_add_op_1 = + pattern->NewNode(ele_add_op_1_repr())->assert_is_op("elementwise_add"); + auto *ele_add_bias_var_1 = pattern->NewNode(ele_add_bias_1_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); + auto *ele_add_out_var_1 = pattern->NewNode(ele_add_out_1_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + ele_add_op_1->LinksFrom({out_var, ele_add_bias_var_1}) + .LinksTo({ele_add_out_var_1}); + out_var = ele_add_out_var_1; + + // Activation + out_var->assert_is_ops_input(act_types); + auto *act_op = pattern->NewNode(act_op_repr())->assert_is_ops(act_types); + auto *act_out_var = + pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out"); + act_op->LinksFrom({out_var}).LinksTo({act_out_var}); + out_var = act_out_var; + + // Dropout1 + if (use_dropout_1) { + out_var->assert_is_op_input("dropout", "X"); + auto *dropout_op_1 = + pattern->NewNode(dropout_op_1_repr())->assert_is_op("dropout"); + auto *dropout_mask_var_1 = pattern->NewNode(dropout_mask_1_repr()) + ->assert_is_op_output("dropout", "Mask"); + auto *dropout_out_var_1 = pattern->NewNode(dropout_out_1_repr()) + ->assert_is_op_output("dropout", "Out"); + dropout_op_1->LinksFrom({out_var}).LinksTo( + {dropout_mask_var_1, 
dropout_out_var_1}); + out_var = dropout_out_var_1; + } + + // Linear2 + out_var->assert_is_op_input("matmul_v2", "X"); + auto *matmul_op_2 = + pattern->NewNode(matmul_op_2_repr())->assert_is_op("matmul_v2"); + auto *matmul_w_var_2 = + pattern->NewNode(matmul_w_2_repr())->assert_is_op_input("matmul_v2", "Y"); + auto *matmul_out_var_2 = pattern->NewNode(matmul_out_2_repr()) + ->assert_is_op_output("matmul_v2", "Out"); + matmul_op_2->LinksFrom({out_var, matmul_w_var_2}).LinksTo({matmul_out_var_2}); + out_var = matmul_out_var_2; + + // Model parallel, do nothing in forward. + if (use_mp) { + out_var->assert_is_op_input("c_allreduce_sum", "X"); + auto *c_allreduce_sum_op = pattern->NewNode(c_allreduce_sum_op_repr()) + ->assert_is_op("c_allreduce_sum"); + auto *c_allreduce_sum_out_var = + pattern->NewNode(c_allreduce_sum_out_repr()) + ->assert_is_op_output("c_allreduce_sum", "Out"); + c_allreduce_sum_op->LinksFrom({out_var}).LinksTo({c_allreduce_sum_out_var}); + out_var = c_allreduce_sum_out_var; + } + + out_var->assert_is_op_input("elementwise_add", "X"); + auto *ele_add_op_2 = + pattern->NewNode(ele_add_op_2_repr())->assert_is_op("elementwise_add"); + auto *ele_add_bias_var_2 = pattern->NewNode(ele_add_bias_2_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + auto *ele_add_out_var_2 = pattern->NewNode(ele_add_out_2_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + ele_add_op_2->LinksFrom({out_var, ele_add_bias_var_2}) + .LinksTo({ele_add_out_var_2}); + out_var = ele_add_out_var_2; + + // Dropout 2 + if (use_dropout_2) { + out_var->assert_is_op_input("dropout", "X"); + auto *dropout_op_2 = + pattern->NewNode(dropout_op_2_repr())->assert_is_op("dropout"); + auto *dropout_mask_var_2 = pattern->NewNode(dropout_mask_2_repr()) + ->assert_is_op_output("dropout", "Mask"); + auto *dropout_out_var_2 = pattern->NewNode(dropout_out_2_repr()) + ->assert_is_op_output("dropout", "Out"); + dropout_op_2->LinksFrom({out_var}).LinksTo( + {dropout_mask_var_2, dropout_out_var_2}); + out_var = dropout_out_var_2; + } + + // Residual Add + if (add_residual) { + out_var->assert_is_op_input("elementwise_add", "X"); + x_var->assert_is_op_input("elementwise_add", "Y"); + auto *ele_add_op_3 = + pattern->NewNode(ele_add_op_3_repr())->assert_is_op("elementwise_add"); + auto *ele_add_out_var_3 = + pattern->NewNode(ele_add_out_3_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + ele_add_op_3->LinksFrom({out_var, x_var}).LinksTo({ele_add_out_var_3}); + out_var = ele_add_out_var_3; + } + + // Post LayerNorm + if (!pre_layer_norm) { + out_var->assert_is_op_input("layer_norm", "X"); + layer_norm_op + ->LinksFrom({out_var, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo( + {layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var}); + out_var = layer_norm_out_var; + } + return out_var; +} + +PDNode *patterns::FusedFeedForwardBwd::operator()( + paddle::framework::ir::PDNode *x_grad, + std::unordered_set<std::string> act_grad_types, + bool use_mp, + bool pre_layer_norm, + bool add_residual, + bool use_dropout_1, + bool use_dropout_2) { + // Possible patterns + // 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad -> + // activation_grad -> linear1_grad -> layer_norm_grad + // 2. 
layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad -> + // dropout1_grad -> activation_grad -> linear1_grad + // other cases: may delete residual_add_grad, dropout1_grad, dropout2_grad + // operators + + // intermediate input_grad, and final pattern ouput_grad + PDNode *out_grad = x_grad; + // LayerNorm: in["Mean", "Variance", "Scale", "Bias", "Y@GRAD"], + // out["X@GRAD", "Scale@GRAD", "Bias@GRAD"] + auto *layer_norm_op_grad = pattern->NewNode(layer_norm_op_grad_repr()) + ->assert_is_op("layer_norm_grad"); + auto *layer_norm_in_var = pattern->NewNode(layer_norm_in_repr()) + ->assert_is_op_input("layer_norm_grad", "X"); + auto *layer_norm_mean_var = + pattern->NewNode(layer_norm_mean_repr()) + ->assert_is_op_input("layer_norm_grad", "Mean"); + auto *layer_norm_variance_var = + pattern->NewNode(layer_norm_variance_repr()) + ->assert_is_op_input("layer_norm_grad", "Variance"); + auto *layer_norm_scale_var = + pattern->NewNode(layer_norm_scale_repr()) + ->assert_is_op_input("layer_norm_grad", "Scale"); + auto *layer_norm_bias_var = + pattern->NewNode(layer_norm_bias_repr()) + ->assert_is_op_input("layer_norm_grad", "Bias"); + auto *layer_norm_in_grad = + pattern->NewNode(layer_norm_in_grad_repr()) + ->assert_is_op_output("layer_norm_grad", GradVarName("X")); + auto *layer_norm_scale_grad = + pattern->NewNode(layer_norm_scale_grad_repr()) + ->assert_is_op_output("layer_norm_grad", GradVarName("Scale")); + auto *layer_norm_bias_grad = + pattern->NewNode(layer_norm_bias_grad_repr()) + ->assert_is_op_output("layer_norm_grad", GradVarName("Bias")); + // post_layer_norm + if (!pre_layer_norm) { + out_grad->assert_is_op_input("layer_norm_grad", GradVarName("Y")); + layer_norm_op_grad + ->LinksFrom({out_grad, + layer_norm_in_var, + layer_norm_mean_var, + layer_norm_variance_var, + layer_norm_scale_var, + layer_norm_bias_var}) + .LinksTo( + {layer_norm_in_grad, layer_norm_scale_grad, layer_norm_bias_grad}); + out_grad = layer_norm_in_grad; + } + // partial input_grad of residual_add + PDNode *tmp = nullptr; + auto *matmul_in_var_1 = pattern->NewNode(matmul_in_1_repr()) + ->assert_is_op_input("matmul_v2_grad", "X"); + if (add_residual) { + // Residual Add: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"] + out_grad->assert_is_op_input("elementwise_add_grad", GradVarName("Out")); + auto *ele_add_op_grad_3 = pattern->NewNode(ele_add_op_grad_3_repr()) + ->assert_is_op("elementwise_add_grad"); + auto *ele_add_in_var_3 = + pattern->NewNode(ele_add_in_3_repr()) + ->assert_is_op_input("elementwise_add_grad", "X"); + auto *ele_add_in_grad_3 = + pattern->NewNode(ele_add_in_grad_3_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("X")); + auto *ele_add_bias_grad_3 = + pattern->NewNode(ele_add_bias_grad_3_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("Y")); + tmp = ele_add_bias_grad_3; + if (pre_layer_norm) { + ele_add_op_grad_3 + ->LinksFrom({out_grad, ele_add_in_var_3, layer_norm_in_var}) + .LinksTo({ele_add_in_grad_3, ele_add_bias_grad_3}); + } else { + ele_add_op_grad_3 + ->LinksFrom({out_grad, ele_add_in_var_3, matmul_in_var_1}) + .LinksTo({ele_add_in_grad_3, ele_add_bias_grad_3}); + } + out_grad = ele_add_in_grad_3; + } + + // Dropout 2: in["Out@GRAD", "Mask"], out["X@GRAD"] + if (use_dropout_2) { + out_grad->assert_is_op_input("dropout_grad", GradVarName("Out")); + auto *dropout_op_grad_2 = pattern->NewNode(dropout_op_grad_2_repr()) + ->assert_is_op("dropout_grad"); + auto *dropout_mask_grad_2 = + pattern->NewNode(dropout_mask_2_repr()) + 
->assert_is_op_input("dropout_grad", "Mask"); + auto *dropout_in_grad_2 = + pattern->NewNode(dropout_in_grad_2_repr()) + ->assert_is_op_output("dropout_grad", GradVarName("X")); + dropout_op_grad_2->LinksFrom({out_grad, dropout_mask_grad_2}) + .LinksTo({dropout_in_grad_2}); + out_grad = dropout_in_grad_2; + } + + // Linear 2: + // elementwise_add: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"] + out_grad->assert_is_op_input("elementwise_add_grad", GradVarName("Out")); + auto *ele_add_op_grad_2 = pattern->NewNode(ele_add_op_grad_2_repr()) + ->assert_is_op("elementwise_add_grad"); + auto *ele_add_in_var_2 = + pattern->NewNode(ele_add_in_2_repr()) + ->assert_is_op_input("elementwise_add_grad", "X"); + auto *ele_add_bias_var_2 = + pattern->NewNode(ele_add_bias_2_repr()) + ->assert_is_op_input("elementwise_add_grad", "Y"); + auto *ele_add_in_grad_2 = + pattern->NewNode(ele_add_in_grad_2_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("X")); + auto *ele_add_bias_grad_2 = + pattern->NewNode(ele_add_bias_grad_2_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("Y")); + ele_add_op_grad_2->LinksFrom({out_grad, ele_add_in_var_2, ele_add_bias_var_2}) + .LinksTo({ele_add_in_grad_2, ele_add_bias_grad_2}); + out_grad = ele_add_in_grad_2; + + // Model parallel, do nothing in backward. + if (use_mp) { + out_grad->assert_is_op_input("c_identity", "X"); + auto *c_identity_op = + pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity"); + auto *c_identity_out_grad = pattern->NewNode(c_identity_out_repr()) + ->assert_is_op_output("c_identity", "Out"); + c_identity_op->LinksFrom({out_grad}).LinksTo({c_identity_out_grad}); + out_grad = c_identity_out_grad; + } + + // matmul_v2: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"] + out_grad->assert_is_op_input("matmul_v2_grad", GradVarName("Out")); + auto *matmul_op_grad_2 = + pattern->NewNode(matmul_op_grad_2_repr())->assert_is_op("matmul_v2_grad"); + auto *matmul_in_var_2 = pattern->NewNode(matmul_in_2_repr()) + ->assert_is_op_input("matmul_v2_grad", "X"); + auto *matmul_w_var_2 = pattern->NewNode(matmul_w_2_repr()) + ->assert_is_op_input("matmul_v2_grad", "Y"); + auto *matmul_in_grad_2 = + pattern->NewNode(matmul_in_grad_2_repr()) + ->assert_is_op_output("matmul_v2_grad", GradVarName("X")); + auto *matmul_w_grad_2 = + pattern->NewNode(matmul_w_grad_2_repr()) + ->assert_is_op_output("matmul_v2_grad", GradVarName("Y")); + matmul_op_grad_2->LinksFrom({out_grad, matmul_in_var_2, matmul_w_var_2}) + .LinksTo({matmul_in_grad_2, matmul_w_grad_2}); + out_grad = matmul_in_grad_2; + + // Dropout 1: in["Out@GRAD", "Mask"], out["X@GRAD"] + if (use_dropout_1) { + out_grad->assert_is_op_input("dropout_grad", GradVarName("Out")); + auto *dropout_op_grad_1 = pattern->NewNode(dropout_op_grad_1_repr()) + ->assert_is_op("dropout_grad"); + auto *dropout_mask_var_1 = pattern->NewNode(dropout_mask_1_repr()) + ->assert_is_op_input("dropout_grad", "Mask"); + auto *dropout_in_grad_1 = + pattern->NewNode(dropout_in_grad_1_repr()) + ->assert_is_op_output("dropout_grad", GradVarName("X")); + dropout_op_grad_1->LinksFrom({out_grad, dropout_mask_var_1}) + .LinksTo({dropout_in_grad_1}); + out_grad = dropout_in_grad_1; + } + + // Activation: in["Out", "Out@GRAD"], out["X@GRAD"] + out_grad->assert_is_ops_input(act_grad_types, GradVarName("Out")); + auto *act_op_grad = + pattern->NewNode(act_op_grad_repr())->assert_is_ops(act_grad_types); + auto *act_in_var = + pattern->NewNode(act_in_repr())->assert_is_ops_input(act_grad_types, "X"); + auto 
*act_in_grad = + pattern->NewNode(act_in_grad_repr()) + ->assert_is_ops_output(act_grad_types, GradVarName("X")); + act_op_grad->LinksFrom({out_grad, act_in_var}).LinksTo({act_in_grad}); + out_grad = act_in_grad; + + // Linear 1: + // elementwise_add: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"] + out_grad->assert_is_op_input("elementwise_add_grad", GradVarName("Out")); + auto *ele_add_op_grad_1 = pattern->NewNode(ele_add_op_grad_1_repr()) + ->assert_is_op("elementwise_add_grad"); + auto *ele_add_in_var_1 = + pattern->NewNode(ele_add_in_1_repr()) + ->assert_is_op_input("elementwise_add_grad", "X"); + auto *ele_add_bias_var_1 = + pattern->NewNode(ele_add_bias_1_repr()) + ->assert_is_op_input("elementwise_add_grad", "Y"); + auto *ele_add_in_grad_1 = + pattern->NewNode(ele_add_in_grad_1_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("X")); + auto *ele_add_bias_grad_1 = + pattern->NewNode(ele_add_bias_grad_1_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("Y")); + ele_add_op_grad_1->LinksFrom({out_grad, ele_add_in_var_1, ele_add_bias_var_1}) + .LinksTo({ele_add_in_grad_1, ele_add_bias_grad_1}); + out_grad = ele_add_in_grad_1; + // matmul_v2: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"] + out_grad->assert_is_op_input("matmul_v2_grad", GradVarName("Out")); + auto *matmul_op_grad_1 = + pattern->NewNode(matmul_op_grad_1_repr())->assert_is_op("matmul_v2_grad"); + // auto *matmul_in_var_1 = pattern->NewNode(matmul_in_1_repr()) + // ->assert_is_op_input("matmul_v2_grad", + // "X"); + auto *matmul_w_var_1 = pattern->NewNode(matmul_w_1_repr()) + ->assert_is_op_input("matmul_v2_grad", "Y"); + auto *matmul_in_grad_1 = + pattern->NewNode(matmul_in_grad_1_repr()) + ->assert_is_op_output("matmul_v2_grad", GradVarName("X")); + auto *matmul_w_grad_1 = + pattern->NewNode(matmul_w_grad_1_repr()) + ->assert_is_op_output("matmul_v2_grad", GradVarName("Y")); + matmul_op_grad_1->LinksFrom({out_grad, matmul_in_var_1, matmul_w_var_1}) + .LinksTo({matmul_in_grad_1, matmul_w_grad_1}); + out_grad = matmul_in_grad_1; + + // Model parallel, all_reduce in backward. 
+  if (use_mp) {
+    out_grad->assert_is_op_input("c_allreduce_sum", "X");
+    auto *c_allreduce_sum_op = pattern->NewNode(c_allreduce_sum_op_repr())
+                                   ->assert_is_op("c_allreduce_sum");
+    auto *c_allreduce_sum_out_grad =
+        pattern->NewNode(c_allreduce_sum_out_repr())
+            ->assert_is_op_output("c_allreduce_sum", "Out");
+    c_allreduce_sum_op->LinksFrom({out_grad})
+        .LinksTo({c_allreduce_sum_out_grad});
+    out_grad = c_allreduce_sum_out_grad;
+  }
+
+  // pre LayerNorm
+  if (pre_layer_norm) {
+    out_grad->assert_is_op_input("layer_norm_grad", GradVarName("Y"));
+    layer_norm_op_grad
+        ->LinksFrom({out_grad,
+                     layer_norm_in_var,
+                     layer_norm_mean_var,
+                     layer_norm_variance_var,
+                     layer_norm_scale_var,
+                     layer_norm_bias_var})
+        .LinksTo(
+            {layer_norm_in_grad, layer_norm_scale_grad, layer_norm_bias_grad});
+    out_grad = layer_norm_in_grad;
+  }
+
+  // sum for final gradient
+  if (add_residual) {
+    auto *sum_op = pattern->NewNode(sum_op_repr())->assert_is_op("sum");
+    auto *sum_out =
+        pattern->NewNode(sum_out_repr())->assert_is_op_output("sum", "Out");
+    sum_op->LinksFrom({tmp, out_grad}).LinksTo({sum_out});
+    out_grad = sum_out;
+  }
+
+  return out_grad;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 7985e5b2b3501994392f2c46dd7bed6c3d6249a8..14d0d4e7b8e3a7379be4fad1f1ad3414fa2525a3 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -2156,6 +2156,133 @@ struct AddSupportInt8 : public PatternBase {
   PATTERN_DECL_NODE(quant_out);
 };
 
+// The following patterns are used to fuse feedforward in the forward pass
+// 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
+//    -> residual_add (pre_layer_norm)
+// 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add
+//    -> layer_norm (post_layer_norm)
+// other cases: the residual_add, dropout1, dropout2 operators may be absent
+struct FusedFeedForwardFwd : public PatternBase {
+  FusedFeedForwardFwd(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "fused_feedforward_fwd") {}
+
+  PDNode* operator()(PDNode* x,
+                     std::unordered_set<std::string> act_types,
+                     bool use_mp,
+                     bool pre_layer_norm,
+                     bool add_residual,
+                     bool use_dropout_1,
+                     bool use_dropout_2);
+
+#ifndef FEEDFORWARD_LINEAR_DROPOUT_NODE
+#define FEEDFORWARD_LINEAR_DROPOUT_NODE(suffix__) \
+  PATTERN_DECL_NODE(matmul_op_##suffix__);        \
+  PATTERN_DECL_NODE(matmul_w_##suffix__);         \
+  PATTERN_DECL_NODE(matmul_out_##suffix__);       \
+  PATTERN_DECL_NODE(ele_add_op_##suffix__);       \
+  PATTERN_DECL_NODE(ele_add_bias_##suffix__);     \
+  PATTERN_DECL_NODE(ele_add_out_##suffix__);      \
+  PATTERN_DECL_NODE(dropout_op_##suffix__);       \
+  PATTERN_DECL_NODE(dropout_out_##suffix__);      \
+  PATTERN_DECL_NODE(dropout_mask_##suffix__);
+
+  // LayerNorm: layer_norm
+  PATTERN_DECL_NODE(layer_norm_op);
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_out);
+  PATTERN_DECL_NODE(layer_norm_mean);
+  PATTERN_DECL_NODE(layer_norm_variance);
+  // Model parallelism
+  PATTERN_DECL_NODE(c_identity_op);
+  PATTERN_DECL_NODE(c_identity_out);
+  PATTERN_DECL_NODE(c_allreduce_sum_op);
+  PATTERN_DECL_NODE(c_allreduce_sum_out);
+  // Linear 1 and Dropout 1: matmul_v2 + elementwise_add + dropout
+  FEEDFORWARD_LINEAR_DROPOUT_NODE(1);
+  // Activation: gelu or relu
+  PATTERN_DECL_NODE(act_op);
+  PATTERN_DECL_NODE(act_out);
+  // Linear 2 and Dropout 2: matmul_v2 + elementwise_add + dropout
+  FEEDFORWARD_LINEAR_DROPOUT_NODE(2);
+  // ResidualAdd: elementwise_add
+  PATTERN_DECL_NODE(ele_add_op_3);
+  PATTERN_DECL_NODE(ele_add_out_3);
+#undef FEEDFORWARD_LINEAR_DROPOUT_NODE
+#endif
+};
+
+// The following patterns are used to fuse feedforward in the backward pass
+// 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad ->
+//    activation_grad -> linear1_grad -> layer_norm_grad
+// 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad ->
+//    dropout1_grad -> activation_grad -> linear1_grad
+// other cases: the residual_add_grad, dropout1_grad, dropout2_grad operators
+// may be absent
+struct FusedFeedForwardBwd : public PatternBase {
+  FusedFeedForwardBwd(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "fused_feedforward_bwd") {}
+
+  PDNode* operator()(PDNode* x,
+                     std::unordered_set<std::string> act_grad_types,
+                     bool use_mp,
+                     bool pre_layer_norm,
+                     bool add_residual,
+                     bool use_dropout_1,
+                     bool use_dropout_2);
+#ifndef FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE
+#define FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(suffix__) \
+  PATTERN_DECL_NODE(matmul_op_grad_##suffix__);        \
+  PATTERN_DECL_NODE(matmul_in_##suffix__);             \
+  PATTERN_DECL_NODE(matmul_w_##suffix__);              \
+  PATTERN_DECL_NODE(matmul_in_grad_##suffix__);        \
+  PATTERN_DECL_NODE(matmul_w_grad_##suffix__);         \
+  PATTERN_DECL_NODE(ele_add_op_grad_##suffix__);       \
+  PATTERN_DECL_NODE(ele_add_in_##suffix__);            \
+  PATTERN_DECL_NODE(ele_add_bias_##suffix__);          \
+  PATTERN_DECL_NODE(ele_add_in_grad_##suffix__);       \
+  PATTERN_DECL_NODE(ele_add_bias_grad_##suffix__);     \
+  PATTERN_DECL_NODE(dropout_op_grad_##suffix__);       \
+  PATTERN_DECL_NODE(dropout_mask_##suffix__);          \
+  PATTERN_DECL_NODE(dropout_in_grad_##suffix__);
+
+  // LayerNorm Grad: layer_norm_grad
+  PATTERN_DECL_NODE(layer_norm_op_grad);
+  PATTERN_DECL_NODE(layer_norm_in);
+  PATTERN_DECL_NODE(layer_norm_mean);
+  PATTERN_DECL_NODE(layer_norm_variance);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_in_grad);
+  PATTERN_DECL_NODE(layer_norm_scale_grad);
+  PATTERN_DECL_NODE(layer_norm_bias_grad);
+  // Model parallelism
+  PATTERN_DECL_NODE(c_identity_op);
+  PATTERN_DECL_NODE(c_identity_out);
+  PATTERN_DECL_NODE(c_allreduce_sum_op);
+  PATTERN_DECL_NODE(c_allreduce_sum_out);
+  // Linear 1 and Dropout 1: matmul_v2_grad + elementwise_add_grad +
+  // dropout_grad
+  FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(1);
+  // Activation Grad: gelu_grad or relu_grad
+  PATTERN_DECL_NODE(act_op_grad);
+  PATTERN_DECL_NODE(act_in);
+  PATTERN_DECL_NODE(act_in_grad);
+  // Linear 2 and Dropout 2: matmul_v2_grad + elementwise_add_grad +
+  // dropout_grad
+  FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(2);
+  // Residual Add: elementwise_add_grad
+  PATTERN_DECL_NODE(ele_add_op_grad_3);
+  PATTERN_DECL_NODE(ele_add_in_3);
+  PATTERN_DECL_NODE(ele_add_bias_3);
+  PATTERN_DECL_NODE(ele_add_in_grad_3);
+  PATTERN_DECL_NODE(ele_add_bias_grad_3);
+  PATTERN_DECL_NODE(sum_op);
+  PATTERN_DECL_NODE(sum_out);
+
+#undef FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE
+#endif
+};
 }  // namespace patterns
 
 // Link two ir::Nodes from each other.
diff --git a/paddle/fluid/pybind/parallel_executor.cc b/paddle/fluid/pybind/parallel_executor.cc
index d23b9e568148725742e0ccf6f6d81141f3558f5a..d1938238a99824cc33ca6be44f65bc1a11740429 100644
--- a/paddle/fluid/pybind/parallel_executor.cc
+++ b/paddle/fluid/pybind/parallel_executor.cc
@@ -723,6 +723,32 @@ void BindParallelExecutor(pybind11::module &m) {  // NOLINT
                     build_strategy = static.BuildStrategy()
                     build_strategy.fused_attention = True
                 )DOC")
+      .def_property(
+          "fused_feedforward",
+          [](const BuildStrategy &self) { return self.fused_feedforward_; },
+          [](BuildStrategy &self, bool b) {
+            PADDLE_ENFORCE_NE(self.IsFinalized(),
+                              true,
+                              platform::errors::PreconditionNotMet(
+                                  "BuildStrategy has been finalized, cannot be "
+                                  "configured again."));
+            self.fused_feedforward_ = b;
+          },
+          R"DOC((bool, optional): fused_feedforward indicates whether
+                to fuse the whole feed_forward part with one operator,
+                which may make execution faster. Default is False.
+
+                Examples:
+                    .. code-block:: python
+
+                        import paddle
+                        import paddle.static as static
+
+                        paddle.enable_static()
+
+                        build_strategy = static.BuildStrategy()
+                        build_strategy.fused_feedforward = True
+                )DOC")
       .def_property(
           "fuse_bn_act_ops",
           [](const BuildStrategy &self) { return self.fuse_bn_act_ops_; },
diff --git a/python/paddle/distributed/passes/cpp_pass.py b/python/paddle/distributed/passes/cpp_pass.py
index 3a791610a51088a99bf9b584ea357af5322d6566..c0da3050463b4bbe8432d852157f0c44c4aba959 100755
--- a/python/paddle/distributed/passes/cpp_pass.py
+++ b/python/paddle/distributed/passes/cpp_pass.py
@@ -84,6 +84,19 @@ class FusedAttentionPass(CPPPassWrapper):
         return PassType.FUSION_OPT
 
 
+@register_pass("fused_feedforward")
+class FusedFeedforwardPass(CPPPassWrapper):
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def cpp_name(self):
+        return "fused_feedforward_pass"
+
+    def _type(self):
+        return PassType.FUSION_OPT
+
+
 @register_pass("fuse_gemm_epilogue")
 class FuseGemmEpiloguePass(CPPPassWrapper):
     def __init__(self):
diff --git a/python/paddle/distributed/passes/pass_base.py b/python/paddle/distributed/passes/pass_base.py
index 1996a3bbf064430d4af30e92ac11f5eef39a9c1d..fca239b41dcc9626049c39d5c20035480d128e80 100755
--- a/python/paddle/distributed/passes/pass_base.py
+++ b/python/paddle/distributed/passes/pass_base.py
@@ -253,6 +253,7 @@ PassBase._PASS_PROCESS_ORDER_LIST = [
     "fuse_bn_add_act",
     "fuse_bn_act",
     "fused_attention",
+    "fused_feedforward",
     "fuse_gemm_epilogue",
     "fuse_optimizer",
 ]
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index ddb96476f5c50d06dcf68d79fc852a6f67f5f530..b5b1e74a942e5c4bac73d2a7be58009a5261a450 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
   list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
   list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
   list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
+  list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass)
 endif()
 
 list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_base_list.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_base_list.py
index 79db714743491a20bd2e5e21615d03d599fb7ab1..f1780b92c39b62ac0e1c85b02e5ee73043a8b581 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_base_list.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_base_list.py
@@ -89,6 +89,7 @@ class TestFusedPassBaseList(unittest.TestCase):
             [
                 "fuse_bn_act",
                 "fused_attention",
+                "fused_feedforward",
                 "fuse_optimizer",
                 "fuse_gemm_epilogue",
                 "fuse_bn_add_act",
diff --git a/python/paddle/fluid/tests/unittests/test_fused_feedforward_pass.py b/python/paddle/fluid/tests/unittests/test_fused_feedforward_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..f071dc24d0abc32837a111db9ef5f2a6996a00ea
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fused_feedforward_pass.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+import paddle.fluid.core as core
+import paddle.nn as nn
+from paddle.distributed.passes import PassManager, new_pass
+
+paddle.enable_static()
+
+
+class FeedForward(nn.Layer):
+    def __init__(
+        self,
+        in_features,
+        hidden_features,
+        out_features,
+        drop_prob=0.1,
+        act_layer=nn.GELU,
+        pre_layer_norm=True,
+        add_residual=True,
+        use_dropout_1=True,
+        use_dropout_2=True,
+    ):
+        super(FeedForward, self).__init__()
+        self.in_features = in_features
+        self.hidden_features = hidden_features
+        self.out_features = out_features
+        self.pre_layer_norm = pre_layer_norm
+        self.add_residual = add_residual
+        self.use_dropout_1 = use_dropout_1
+        self.use_dropout_2 = use_dropout_2
+
+        self.fc1 = nn.Linear(in_features, in_features)
+        self.fc2 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc3 = nn.Linear(hidden_features, out_features)
+        self.drop1 = nn.Dropout(drop_prob)
+        self.drop2 = nn.Dropout(drop_prob)
+        self.norm = nn.LayerNorm(in_features, epsilon=1e-5)
+        self.fc4 = nn.Linear(out_features, out_features)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        residual = x
+        if self.pre_layer_norm:
+            x = self.norm(x)
+        x = self.fc2(x)
+        x = self.act(x)
+        if self.use_dropout_1:
+            x = self.drop1(x)
+        x = self.fc3(x)
+        if self.use_dropout_2:
+            x = self.drop2(x)
+        if self.add_residual:
+            x += residual
+        if not self.pre_layer_norm:
+            x = self.norm(x)
+        x = self.fc4(x)
+
+        return x
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestFusedFeedforwardPass(unittest.TestCase):
+    def setUp(self):
+        self.pre_layer_norm = True
+        self.add_residual = True
+        self.use_dropout_1 = True
+        self.use_dropout_2 = True
+
+    def get_value(self, use_pass=False):
+        batch_size = 2
+        in_features = 768
+        hidden_features = 3072
+        out_features = 768
+        act_layer = nn.GELU
+        pre_layer_norm = self.pre_layer_norm
+        add_residual = self.add_residual
+        use_dropout_1 = self.use_dropout_1
+        use_dropout_2 = self.use_dropout_2
+
+        np.random.seed(1234)
+        x_data = np.random.rand(batch_size, in_features, in_features).astype(
+            'float32'
+        )
+
+        main_prog = paddle.static.Program()
+        main_prog.random_seed = 1234
+        startup_prog = paddle.static.Program()
+        startup_prog.random_seed = 1234
+
+        with paddle.static.program_guard(main_prog, startup_prog):
+            data = paddle.static.data(
+                name="x",
+                shape=[2, in_features, in_features],
+                dtype='float32',
+            )
+
+            feed_forward = FeedForward(
+                in_features=in_features,
+                hidden_features=hidden_features,
+                out_features=out_features,
+                drop_prob=1e-10,
+                act_layer=act_layer,
+                pre_layer_norm=pre_layer_norm,
+                add_residual=add_residual,
+                use_dropout_1=use_dropout_1,
+                use_dropout_2=use_dropout_2,
+            )
+
+            out = feed_forward(data)
+
+            loss = paddle.mean(out)
+            sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
+            sgd_optimizer.minimize(loss)
+
+        if use_pass:
+            pass_manager = PassManager([new_pass("fused_feedforward")])
+            pass_manager.apply([main_prog], [startup_prog])
+
+            ops = main_prog.global_block().ops
+            assert 'fused_feedforward' in [op.type for op in ops]
+            assert 'fused_feedforward_grad' in [op.type for op in ops]
+
+        exe = paddle.static.Executor(paddle.CUDAPlace(0))
+        exe.run(startup_prog)
+
+        for i in range(2):
+            ret_loss = exe.run(
+                main_prog, feed={"x": x_data}, fetch_list=[loss.name]
+            )
+
+        return ret_loss
+
+    def test_pass(self):
+        for pre_layer_norm in [True, False]:
+            for add_residual in [True, False]:
+                for use_dropout_1 in [True, False]:
+                    for use_dropout_2 in [True, False]:
+                        if not pre_layer_norm and not add_residual:
+                            continue
+                        if not use_dropout_1 and not use_dropout_2:
+                            continue
+                        self.pre_layer_norm = pre_layer_norm
+                        self.add_residual = add_residual
+                        self.use_dropout_1 = use_dropout_1
+                        self.use_dropout_2 = use_dropout_2
+                        ret_loss = self.get_value()
+                        ret_loss_fused = self.get_value(use_pass=True)
+                        assert np.allclose(ret_loss, ret_loss_fused)
+
+
+if __name__ == "__main__":
+    unittest.main()
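
For reference, a minimal usage sketch (assuming a CUDA build of Paddle): the fusion added here can be enabled either through the new BuildStrategy flag bound in paddle/fluid/pybind/parallel_executor.cc, or by applying the registered IR pass with the distributed PassManager, mirroring the unit test above. The program construction is elided and the program variables are placeholders.

    import paddle
    import paddle.static as static
    from paddle.distributed.passes import PassManager, new_pass

    paddle.enable_static()

    # Option 1: let the static-graph executor insert the pass via BuildStrategy.
    build_strategy = static.BuildStrategy()
    build_strategy.fused_feedforward = True

    # Option 2: rewrite the programs explicitly, as test_fused_feedforward_pass.py does.
    main_prog, startup_prog = static.Program(), static.Program()
    # ... build the feed-forward network under static.program_guard(main_prog, startup_prog) ...
    pass_manager = PassManager([new_pass("fused_feedforward")])
    pass_manager.apply([main_prog], [startup_prog])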