From b0ece26623f1424c4c2568939160e0c482d169b4 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 10 Jan 2023 10:51:20 +0800 Subject: [PATCH] [Fuse attention pass] Forward pattern. (#49621) --- paddle/fluid/framework/details/CMakeLists.txt | 1 + .../fluid/framework/details/build_strategy.cc | 7 + .../fluid/framework/details/build_strategy.h | 3 + .../framework/distributed_strategy.proto | 1 + paddle/fluid/framework/ir/CMakeLists.txt | 4 + .../framework/ir/fused_attention_pass.cc | 448 ++++++++++++++++++ .../fluid/framework/ir/fused_attention_pass.h | 176 +++++++ paddle/fluid/pybind/parallel_executor.cc | 26 + python/paddle/distributed/passes/cpp_pass.py | 13 + .../fluid/tests/unittests/CMakeLists.txt | 1 + .../unittests/test_fused_attention_pass.py | 164 +++++++ 11 files changed, 844 insertions(+) create mode 100644 paddle/fluid/framework/ir/fused_attention_pass.cc create mode 100644 paddle/fluid/framework/ir/fused_attention_pass.h create mode 100644 python/paddle/fluid/tests/unittests/test_fused_attention_pass.py diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 8ce39f0db7..5d31b443c1 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -382,6 +382,7 @@ set(IR_PASS_DEPS graph_to_program_pass fix_op_run_order_pass fuse_gemm_epilogue_pass + fused_attention_pass delete_dropout_op_pass) if(WITH_CINN) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 486770cdbd..47a262ea35 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -187,6 +187,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass"); #endif +#ifdef PADDLE_WITH_CUDA + AppendPassWithCheck(strategy_.fused_attention_, "fused_attention_pass"); +#endif + #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) AppendPassWithCheck(strategy_.fuse_gemm_epilogue_, "fuse_gemm_epilogue_pass"); @@ -519,6 +523,9 @@ USE_PASS(fuse_all_reduce_op_pass); USE_PASS(runtime_context_cache_pass); USE_PASS(add_reader_dependency_pass); USE_PASS(delete_dropout_op_x_pass); +#ifdef PADDLE_WITH_CUDA +USE_PASS(fused_attention_pass); +#endif #ifdef PADDLE_WITH_CINN USE_PASS(build_cinn_pass); #endif diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 4d51099529..29e390bf0f 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -129,6 +129,8 @@ struct BuildStrategy { bool sync_batch_norm_{false}; // Fuse GEMM+Epilogue via cublasLt epilogue. bool fuse_gemm_epilogue_{false}; + // Fused multi head attention + bool fused_attention_{false}; // mkldnn_enabled_op_types specify the operator type list to // use MKLDNN acceleration. It is null in default, means @@ -261,6 +263,7 @@ inline std::ostream &operator<<(std::ostream &os, os << "fuse_broadcast_ops_: " << strategy.fuse_broadcast_ops_ << std::endl; os << "sync_batch_norm_: " << strategy.sync_batch_norm_ << std::endl; os << "fuse_gemm_epilogue_: " << strategy.fuse_gemm_epilogue_ << std::endl; + os << "fused_attention_: " << strategy.fused_attention_ << std::endl; os << "mkldnn_enabled_op_types_: "; for (auto str : strategy.mkldnn_enabled_op_types_) { os << str << ", "; diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 52c24fffc7..27bc7c7030 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -124,6 +124,7 @@ message BuildStrategy { optional int32 reduce_strategy = 15 [ default = 0 ]; optional bool fuse_gemm_epilogue = 16 [ default = false ]; optional string debug_graphviz_path = 17; + optional bool fused_attention = 18 [ default = false]; } message ExecutionStrategy { diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 8a22eb87db..1a84e815e0 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -223,6 +223,10 @@ cc_library( fuse_gemm_epilogue_pass SRCS fuse_gemm_epilogue_pass.cc DEPS pass graph_pattern_detector) +cc_library( + fused_attention_pass + SRCS fused_attention_pass.cc + DEPS pass graph_pattern_detector) cc_library( fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc diff --git a/paddle/fluid/framework/ir/fused_attention_pass.cc b/paddle/fluid/framework/ir/fused_attention_pass.cc new file mode 100644 index 0000000000..771bf958d2 --- /dev/null +++ b/paddle/fluid/framework/ir/fused_attention_pass.cc @@ -0,0 +1,448 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fused_attention_pass.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +PDNode* FusedAttentionPattern::operator()(PDNode* x, + bool pre_layer_norm, + bool post_layer_norm, + bool has_attn_mask, + bool do_dropout, + bool add_residual) { + // pre layer norm pattern + PDNode* pre_layer_norm_out_node{nullptr}; + if (pre_layer_norm) { + auto* pre_layer_norm_node = + pattern->NewNode(pre_layer_norm_op_repr())->assert_is_op("layer_norm"); + auto* pre_layer_norm_scale_node = + pattern->NewNode(pre_layer_norm_scale_repr()) + ->assert_is_op_input("layer_norm", "Scale"); + auto* pre_layer_norm_bias_node = + pattern->NewNode(pre_layer_norm_bias_repr()) + ->assert_is_op_input("layer_norm", "Bias"); + pre_layer_norm_out_node = pattern->NewNode(pre_layer_norm_out_repr()) + ->assert_is_op_output("layer_norm", "Y"); + auto* pre_layer_norm_mean_node = + pattern->NewNode(pre_layer_norm_mean_repr()) + ->assert_is_op_output("layer_norm", "Mean"); + auto* pre_layer_norm_variance_node = + pattern->NewNode(pre_layer_norm_variance_repr()) + ->assert_is_op_output("layer_norm", "Variance"); + pre_layer_norm_node + ->LinksFrom({x, pre_layer_norm_scale_node, pre_layer_norm_bias_node}) + .LinksTo({pre_layer_norm_out_node, + pre_layer_norm_mean_node, + pre_layer_norm_variance_node}); + } + + // fuse qkv pattern + auto* fuse_qkv_matmul_node = + pattern->NewNode(fuse_qkv_matmul_op_repr())->assert_is_op("matmul_v2"); + auto* fuse_qkv_matmul_w_node = pattern->NewNode(fuse_qkv_matmul_w_repr()) + ->assert_is_op_input("matmul_v2", "Y"); + auto* fuse_qkv_matmul_out_node = pattern->NewNode(fuse_qkv_matmul_out_repr()) + ->assert_is_op_output("matmul_v2"); + if (pre_layer_norm) { + pre_layer_norm_out_node->assert_is_op_input("matmul_v2", "X"); + fuse_qkv_matmul_node + ->LinksFrom({pre_layer_norm_out_node, fuse_qkv_matmul_w_node}) + .LinksTo({fuse_qkv_matmul_out_node}); + } else { + fuse_qkv_matmul_node->LinksFrom({x, fuse_qkv_matmul_w_node}) + .LinksTo({fuse_qkv_matmul_out_node}); + } + + auto* fuse_qkv_ele_add_node = pattern->NewNode(fuse_qkv_ele_add_op_repr()) + ->assert_is_op("elementwise_add"); + auto* fuse_qkv_ele_add_bias_node = + pattern->NewNode(fuse_qkv_ele_add_bias_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + auto* fuse_qkv_ele_add_out_node = + pattern->NewNode(fuse_qkv_ele_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + fuse_qkv_matmul_out_node->assert_is_op_input("elementwise_add", "X"); + fuse_qkv_ele_add_node + ->LinksFrom({fuse_qkv_matmul_out_node, fuse_qkv_ele_add_bias_node}) + .LinksTo({fuse_qkv_ele_add_out_node}); + + auto* fuse_qkv_reshape_node = + pattern->NewNode(fuse_qkv_reshape_op_repr())->assert_is_op("reshape2"); + auto* fuse_qkv_reshape_x_shape_node = + pattern->NewNode(fuse_qkv_reshape_x_shape_repr()) + ->assert_is_op_output("reshape2", "XShape"); + auto* fuse_qkv_reshape_out_node = + pattern->NewNode(fuse_qkv_reshape_out_repr()) + ->assert_is_op_output("reshape2"); + fuse_qkv_ele_add_out_node->assert_is_op_input("reshape2", "X"); + fuse_qkv_reshape_node->LinksFrom({fuse_qkv_ele_add_out_node}) + .LinksTo({fuse_qkv_reshape_x_shape_node, fuse_qkv_reshape_out_node}); + + auto* fuse_qkv_transpose_node = pattern->NewNode(fuse_qkv_transpose_op_repr()) + ->assert_is_op("transpose2"); + auto* fuse_qkv_transpose_x_shape_node = + pattern->NewNode(fuse_qkv_transpose_x_shape_repr()) + ->assert_is_op_output("transpose2", "XShape"); + auto* fuse_qkv_transpose_out_node = + pattern->NewNode(fuse_qkv_transpose_out_repr()) + ->assert_is_op_output("transpose2"); + fuse_qkv_reshape_out_node->assert_is_op_input("transpose2", "X"); + fuse_qkv_transpose_node->LinksFrom({fuse_qkv_reshape_out_node}) + .LinksTo({fuse_qkv_transpose_x_shape_node, fuse_qkv_transpose_out_node}); + + auto* fuse_qkv_split_node = + pattern->NewNode(fuse_qkv_split_op_repr())->assert_is_op("split"); + auto* fuse_qkv_split_out_q_node = + pattern->NewNode(fuse_qkv_split_out_q_repr()) + ->assert_is_op_output("split"); + auto* fuse_qkv_split_out_k_node = + pattern->NewNode(fuse_qkv_split_out_k_repr()) + ->assert_is_op_output("split"); + auto* fuse_qkv_split_out_v_node = + pattern->NewNode(fuse_qkv_split_out_v_repr()) + ->assert_is_op_output("split"); + fuse_qkv_transpose_out_node->assert_is_op_input("split", "X"); + fuse_qkv_split_node->LinksFrom({fuse_qkv_transpose_out_node}) + .LinksTo({fuse_qkv_split_out_q_node, + fuse_qkv_split_out_k_node, + fuse_qkv_split_out_v_node}); + + // core attention pattern + auto* qk_matmul_node = + pattern->NewNode(qk_matmul_op_repr())->assert_is_op("matmul_v2"); + auto* qk_matmul_out_node = + pattern->NewNode(qk_matmul_out_repr())->assert_is_op_output("matmul_v2"); + fuse_qkv_split_out_q_node->assert_is_op_input("matmul_v2", "X"); + fuse_qkv_split_out_k_node->assert_is_op_input("matmul_v2", "Y"); + qk_matmul_node + ->LinksFrom({fuse_qkv_split_out_q_node, fuse_qkv_split_out_k_node}) + .LinksTo({qk_matmul_out_node}); + + auto* qk_scale_node = + pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale"); + auto* qk_scale_out_node = + pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale"); + qk_matmul_out_node->assert_is_op_input("scale", "X"); + qk_scale_node->LinksFrom({qk_matmul_out_node}).LinksTo({qk_scale_out_node}); + + PDNode* add_mask_ele_add_out_node{nullptr}; + if (has_attn_mask) { + auto* add_mask_ele_add_node = pattern->NewNode(add_mask_ele_add_op_repr()) + ->assert_is_op("elementwise_add"); + auto* add_mask_ele_add_mask_node = + pattern->NewNode(add_mask_ele_add_mask_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + add_mask_ele_add_out_node = pattern->NewNode(add_mask_ele_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + qk_scale_out_node->assert_is_op_input("elementwise_add", "X"); + add_mask_ele_add_node + ->LinksFrom({qk_scale_out_node, add_mask_ele_add_mask_node}) + .LinksTo({add_mask_ele_add_out_node}); + } + + auto* qk_softmax_node = + pattern->NewNode(qk_softmax_op_repr())->assert_is_op("softmax"); + auto* qk_softmax_out_node = + pattern->NewNode(qk_softmax_out_repr())->assert_is_op_output("softmax"); + if (has_attn_mask) { + add_mask_ele_add_out_node->assert_is_op_input("softmax", "X"); + qk_softmax_node->LinksFrom({add_mask_ele_add_out_node}) + .LinksTo({qk_softmax_out_node}); + } else { + qk_scale_out_node->assert_is_op_input("softmax", "X"); + qk_softmax_node->LinksFrom({qk_scale_out_node}) + .LinksTo({qk_softmax_out_node}); + } + + PDNode* attn_dropout_out_node{nullptr}; + if (do_dropout) { + auto* attn_dropout_node = + pattern->NewNode(attn_dropout_op_repr())->assert_is_op("dropout"); + auto* attn_dropout_mask_node = pattern->NewNode(attn_dropout_mask_repr()) + ->assert_is_op_output("dropout", "Mask"); + attn_dropout_out_node = pattern->NewNode(attn_dropout_out_repr()) + ->assert_is_op_output("dropout"); + qk_softmax_out_node->assert_is_op_input("dropout", "X"); + attn_dropout_node->LinksFrom({qk_softmax_out_node}) + .LinksTo({attn_dropout_mask_node, attn_dropout_out_node}); + } + + auto* qkv_matmul_node = + pattern->NewNode(qkv_matmul_op_repr())->assert_is_op("matmul_v2"); + auto* qkv_matmul_out_node = + pattern->NewNode(qkv_matmul_out_repr())->assert_is_op_output("matmul_v2"); + fuse_qkv_split_out_v_node->assert_is_op_input("matmul_v2", "Y"); + if (do_dropout) { + attn_dropout_out_node->assert_is_op_input("matmul_v2", "X"); + qkv_matmul_node + ->LinksFrom({attn_dropout_out_node, fuse_qkv_split_out_v_node}) + .LinksTo({qkv_matmul_out_node}); + } else { + qk_softmax_out_node->assert_is_op_input("matmul_v2", "X"); + qkv_matmul_node->LinksFrom({qk_softmax_out_node, fuse_qkv_split_out_v_node}) + .LinksTo({qkv_matmul_out_node}); + } + + auto* qkv_transpose_node = + pattern->NewNode(qkv_transpose_op_repr())->assert_is_op("transpose2"); + auto* qkv_transpose_x_shape_node = + pattern->NewNode(qkv_transpose_x_shape_repr()) + ->assert_is_op_output("transpose2", "XShape"); + auto* qkv_transpose_out_node = pattern->NewNode(qkv_transpose_out_repr()) + ->assert_is_op_output("transpose2"); + qkv_matmul_out_node->assert_is_op_input("transpose2", "X"); + qkv_transpose_node->LinksFrom({qkv_matmul_out_node}) + .LinksTo({qkv_transpose_x_shape_node, qkv_transpose_out_node}); + + auto* qkv_reshape_node = + pattern->NewNode(qkv_reshape_op_repr())->assert_is_op("reshape2"); + auto* qkv_reshape_x_shape_node = + pattern->NewNode(qkv_reshape_x_shape_repr()) + ->assert_is_op_output("reshape2", "XShape"); + auto* qkv_reshape_out_node = + pattern->NewNode(qkv_reshape_out_repr())->assert_is_op_output("reshape2"); + qkv_transpose_out_node->assert_is_op_input("reshape2", "X"); + qkv_reshape_node->LinksFrom({qkv_transpose_out_node}) + .LinksTo({qkv_reshape_x_shape_node, qkv_reshape_out_node}); + + // out linear pattern + auto* out_linear_matmul_node = + pattern->NewNode(out_linear_matmul_op_repr())->assert_is_op("matmul_v2"); + auto* out_linear_matmul_w_node = pattern->NewNode(out_linear_matmul_w_repr()) + ->assert_is_op_input("matmul_v2", "Y"); + auto* out_linear_matmul_out_node = + pattern->NewNode(out_linear_matmul_out_repr()) + ->assert_is_op_output("matmul_v2"); + qkv_reshape_out_node->assert_is_op_input("matmul_v2", "X"); + out_linear_matmul_node + ->LinksFrom({qkv_reshape_out_node, out_linear_matmul_w_node}) + .LinksTo({out_linear_matmul_out_node}); + + auto* out_linear_ele_add_node = pattern->NewNode(out_linear_ele_add_op_repr()) + ->assert_is_op("elementwise_add"); + auto* out_linear_ele_add_bias_node = + pattern->NewNode(out_linear_ele_add_bias_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + auto* out_linear_ele_add_out_node = + pattern->NewNode(out_linear_ele_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + out_linear_matmul_out_node->assert_is_op_input("elementwise_add", "X"); + out_linear_ele_add_node + ->LinksFrom({out_linear_matmul_out_node, out_linear_ele_add_bias_node}) + .LinksTo({out_linear_ele_add_out_node}); + + auto* out_linear_dropout_node = + pattern->NewNode(out_linear_dropout_op_repr())->assert_is_op("dropout"); + auto* out_linear_dropout_mask_node = + pattern->NewNode(out_linear_dropout_mask_repr()) + ->assert_is_op_output("dropout", "Mask"); + auto* out_linear_dropout_out_node = + pattern->NewNode(out_linear_dropout_out_repr()) + ->assert_is_op_output("dropout"); + out_linear_ele_add_out_node->assert_is_op_input("dropout", "X"); + out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node}) + .LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node}); + + if (!add_residual && !post_layer_norm) { + return out_linear_dropout_out_node; + } + + // add residual + PDNode* residual_ele_add_out_node{nullptr}; + if (add_residual) { + // this kind of pattern only support `residual + dropout_out`, since we have + // to fix X and Y + auto* residual_ele_add_node = pattern->NewNode(residual_ele_add_op_repr()) + ->assert_is_op("elementwise_add"); + residual_ele_add_out_node = pattern->NewNode(residual_ele_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + out_linear_dropout_out_node->assert_is_op_input("elementwise_add", "Y"); + residual_ele_add_node->LinksFrom({x, out_linear_dropout_out_node}) + .LinksTo({residual_ele_add_out_node}); + + if (!post_layer_norm) { + return residual_ele_add_out_node; + } + } + + // post layer norm + auto* post_layer_norm_node = + pattern->NewNode(post_layer_norm_op_repr())->assert_is_op("layer_norm"); + auto* post_layer_norm_scale_node = + pattern->NewNode(post_layer_norm_scale_repr()) + ->assert_is_op_input("layer_norm", "Scale"); + auto* post_layer_norm_bias_node = + pattern->NewNode(post_layer_norm_bias_repr()) + ->assert_is_op_input("layer_norm", "Bias"); + auto* post_layer_norm_out_node = pattern->NewNode(post_layer_norm_out_repr()) + ->assert_is_op_output("layer_norm", "Y"); + auto* post_layer_norm_mean_node = + pattern->NewNode(post_layer_norm_mean_repr()) + ->assert_is_op_output("layer_norm", "Mean"); + auto* post_layer_norm_variance_node = + pattern->NewNode(post_layer_norm_variance_repr()) + ->assert_is_op_output("layer_norm", "Variance"); + if (add_residual) { + residual_ele_add_out_node->assert_is_op_input("layer_norm", "X"); + post_layer_norm_node + ->LinksFrom({residual_ele_add_out_node, + post_layer_norm_scale_node, + post_layer_norm_bias_node}) + .LinksTo({post_layer_norm_out_node, + post_layer_norm_mean_node, + post_layer_norm_variance_node}); + } else { + out_linear_dropout_out_node->assert_is_op_input("layer_norm", "X"); + post_layer_norm_node + ->LinksFrom({out_linear_dropout_out_node, + post_layer_norm_scale_node, + post_layer_norm_bias_node}) + .LinksTo({post_layer_norm_out_node, + post_layer_norm_mean_node, + post_layer_norm_variance_node}); + } + + return post_layer_norm_out_node; +} + +PDNode* FusedAttentionGradPattern::operator()(PDNode* x, + bool pre_layer_norm, + bool post_layer_norm, + bool has_attn_mask, + bool do_dropout, + bool add_residual) { + // TODO(Yuang Liu): finish the backward pattern + return nullptr; +} + +} // namespace patterns + +void FusedAttentionsPass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + + graph = PreMaskDropResPostFwd(graph); + graph = PreMaskDropResPostBwd(graph); +} + +ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const { + GraphPatternDetector gpd; + auto* x = gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "x")) + ->AsInput() + ->assert_is_op_input("layer_norm", "X"); + patterns::FusedAttentionPattern fused_attention_pattern( + gpd.mutable_pattern(), "fused_attention_pattern"); + + fused_attention_pattern(x, + /* pre_layer_norm */ true, + /* post_layer_norm */ true, + /* has_attn_mask */ true, + /* do_dropout */ true, + /* add_residual */ true); + + int found_fused_attention = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(3) << "handle FusedMultiHeadAttention pass's fusion"; + + GET_IR_NODE_FROM_SUBGRAPH( + pre_layer_norm_op_node, pre_layer_norm_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + fuse_qkv_matmul_op_node, fuse_qkv_matmul_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + fuse_qkv_ele_add_op_node, fuse_qkv_ele_add_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + fuse_qkv_reshape_op_node, fuse_qkv_reshape_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_transpose_op_node, + fuse_qkv_transpose_op, + fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + fuse_qkv_split_op_node, fuse_qkv_split_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + qk_matmul_op_node, qk_matmul_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + qk_scale_op_node, qk_scale_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + add_mask_ele_add_op_node, add_mask_ele_add_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + qk_softmax_op_node, qk_softmax_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + attn_dropout_op_node, attn_dropout_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + qkv_matmul_op_node, qkv_matmul_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + qkv_transpose_op_node, qkv_transpose_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + qkv_reshape_op_node, qkv_reshape_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_op_node, + out_linear_matmul_op, + fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_op_node, + out_linear_ele_add_op, + fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_op_node, + out_linear_dropout_op, + fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + post_layer_norm_op_node, post_layer_norm_op, fused_attention_pattern); + + // TODO(Yuang Liu): finish the handler + + GraphSafeRemoveNodes(g, + {pre_layer_norm_op_node, + fuse_qkv_matmul_op_node, + fuse_qkv_ele_add_op_node, + fuse_qkv_reshape_op_node, + fuse_qkv_transpose_op_node, + fuse_qkv_split_op_node, + qk_matmul_op_node, + qk_scale_op_node, + add_mask_ele_add_op_node, + qk_softmax_op_node, + attn_dropout_op_node, + qkv_matmul_op_node, + qkv_transpose_op_node, + qkv_reshape_op_node, + out_linear_matmul_op_node, + out_linear_ele_add_op_node, + out_linear_dropout_op_node, + residual_ele_add_op_node, + post_layer_norm_op_node}); + found_fused_attention++; + }; + + gpd(graph, handler); + AddStatis(found_fused_attention); + + return graph; +} + +ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const { + // TODO(Yuang Liu): finish the pass + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fused_attention_pass, paddle::framework::ir::FusedAttentionsPass); diff --git a/paddle/fluid/framework/ir/fused_attention_pass.h b/paddle/fluid/framework/ir/fused_attention_pass.h new file mode 100644 index 0000000000..5ec1aac41e --- /dev/null +++ b/paddle/fluid/framework/ir/fused_attention_pass.h @@ -0,0 +1,176 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +// Declare patterns for multi head attention. +// Can detect: +// 1. Pre layer norm, post layer norm or sandwich layer norm. +// 2. Add attn mask for qk product before the softmax or not. +// 3. Do attn dropout or not. +// 4. Add residual to the out linear result or not. +struct FusedAttentionPattern : public PatternBase { + FusedAttentionPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "fused_attention_pattern") {} + + PDNode* operator()(PDNode* x, + bool pre_layer_norm, // do pre ln or not + bool post_layer_norm, // do post ln or not + bool has_attn_mask, // add attn mask to qk or not + bool do_dropout, // dropout the softmax(qk) or not + bool add_residual); // add residual to out linear or not + + // pre layer norm + PATTERN_DECL_NODE(pre_layer_norm_op); + PATTERN_DECL_NODE(pre_layer_norm_scale); + PATTERN_DECL_NODE(pre_layer_norm_bias); + PATTERN_DECL_NODE(pre_layer_norm_out); + PATTERN_DECL_NODE(pre_layer_norm_mean); + PATTERN_DECL_NODE(pre_layer_norm_variance); + + // fuse qkv projection + PATTERN_DECL_NODE(fuse_qkv_matmul_op); + PATTERN_DECL_NODE(fuse_qkv_matmul_w); + PATTERN_DECL_NODE(fuse_qkv_matmul_out); + + PATTERN_DECL_NODE(fuse_qkv_ele_add_op); + PATTERN_DECL_NODE(fuse_qkv_ele_add_bias); + PATTERN_DECL_NODE(fuse_qkv_ele_add_out); + + PATTERN_DECL_NODE(fuse_qkv_reshape_op); + PATTERN_DECL_NODE(fuse_qkv_reshape_out); + PATTERN_DECL_NODE(fuse_qkv_reshape_x_shape); + + PATTERN_DECL_NODE(fuse_qkv_transpose_op); + PATTERN_DECL_NODE(fuse_qkv_transpose_out); + PATTERN_DECL_NODE(fuse_qkv_transpose_x_shape); + + PATTERN_DECL_NODE(fuse_qkv_split_op); + PATTERN_DECL_NODE(fuse_qkv_split_out_q); // q + PATTERN_DECL_NODE(fuse_qkv_split_out_k); // k + PATTERN_DECL_NODE(fuse_qkv_split_out_v); // v + + // core attention + PATTERN_DECL_NODE(qk_matmul_op); + PATTERN_DECL_NODE(qk_matmul_out); + + PATTERN_DECL_NODE(qk_scale_op); + PATTERN_DECL_NODE(qk_scale_out); + + PATTERN_DECL_NODE(add_mask_ele_add_op); + PATTERN_DECL_NODE(add_mask_ele_add_mask); + PATTERN_DECL_NODE(add_mask_ele_add_out); + + PATTERN_DECL_NODE(qk_softmax_op); + PATTERN_DECL_NODE(qk_softmax_out); + + PATTERN_DECL_NODE(attn_dropout_op); + PATTERN_DECL_NODE(attn_dropout_out); + PATTERN_DECL_NODE(attn_dropout_mask); + + PATTERN_DECL_NODE(qkv_matmul_op); + PATTERN_DECL_NODE(qkv_matmul_out); + + PATTERN_DECL_NODE(qkv_transpose_op); + PATTERN_DECL_NODE(qkv_transpose_out); + PATTERN_DECL_NODE(qkv_transpose_x_shape); + + PATTERN_DECL_NODE(qkv_reshape_op); + PATTERN_DECL_NODE(qkv_reshape_out); + PATTERN_DECL_NODE(qkv_reshape_x_shape); + + // out linear + PATTERN_DECL_NODE(out_linear_matmul_op); + PATTERN_DECL_NODE(out_linear_matmul_w); + PATTERN_DECL_NODE(out_linear_matmul_out); + + PATTERN_DECL_NODE(out_linear_ele_add_op); + PATTERN_DECL_NODE(out_linear_ele_add_bias); + PATTERN_DECL_NODE(out_linear_ele_add_out); + + PATTERN_DECL_NODE(out_linear_dropout_op); + PATTERN_DECL_NODE(out_linear_dropout_out); + PATTERN_DECL_NODE(out_linear_dropout_mask); + + // residual + PATTERN_DECL_NODE(residual_ele_add_op); + PATTERN_DECL_NODE(residual_ele_add_out); + + // post layer norm + PATTERN_DECL_NODE(post_layer_norm_op); + PATTERN_DECL_NODE(post_layer_norm_scale); + PATTERN_DECL_NODE(post_layer_norm_bias); + PATTERN_DECL_NODE(post_layer_norm_out); + PATTERN_DECL_NODE(post_layer_norm_mean); + PATTERN_DECL_NODE(post_layer_norm_variance); +}; + +// Declare the grad pattern for multi head attention +struct FusedAttentionGradPattern : public PatternBase { + FusedAttentionGradPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "fused_attention_pattern") {} + + PDNode* operator()(PDNode* x, + bool pre_layer_norm, // pre ln + bool post_layer_norm, // post ln + bool has_attn_mask, // add attn mask to qk or not + bool do_dropout, // dropout the softmax(qk) or not + bool add_residual); // add residual to out linear or not + + // TODO(Yuang Liu): add backward pattern +}; + +} // namespace patterns + +class FusedAttentionsPass : public FusePassBase { + public: + virtual ~FusedAttentionsPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"fused_attention_pass"}; + + private: + // The name rule for the helper function. + // The function name will contain at most five parts in order: + // 1. Do pre layer norm? [Pre] + // 2. Add mask in the core attention part? [Mask] + // 3. Do dropout in the core attention part? [Drop] + // 4. Add residual? [Res] + // 5. Do post layer norm? [Post] + // 6. Forward or Backward? [Fwd/Bwd] + // If true, the function name will have an abbreviation part. + // If false, the function name won't contain an abbreviation for it. + + ir::Graph* PreMaskDropResPostFwd(Graph* graph) const; + + ir::Graph* PreMaskDropResPostBwd(Graph* graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/pybind/parallel_executor.cc b/paddle/fluid/pybind/parallel_executor.cc index 962bdd736f..9ca2682462 100644 --- a/paddle/fluid/pybind/parallel_executor.cc +++ b/paddle/fluid/pybind/parallel_executor.cc @@ -714,6 +714,32 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT build_strategy = static.BuildStrategy() build_strategy.fuse_gemm_epilogue = True )DOC") + .def_property( + "fused_attention", + [](const BuildStrategy &self) { return self.fused_attention_; }, + [](BuildStrategy &self, bool b) { + PADDLE_ENFORCE_NE(self.IsFinalized(), + true, + platform::errors::PreconditionNotMet( + "BuildStrategy has been finlaized, cannot be " + "configured again.")); + self.fused_attention_ = b; + }, + R"DOC((bool, optional): fused_attention indicate whether + to fuse the whole multi head attention part with one op, + it may make the execution faster. Default is False. + + Examples: + .. code-block:: python + + import paddle + import paddle.static as static + + paddle.enable_static() + + build_strategy = static.BuildStrategy() + build_strategy.fused_attention = True + )DOC") .def_property( "fuse_bn_act_ops", [](const BuildStrategy &self) { return self.fuse_bn_act_ops_; }, diff --git a/python/paddle/distributed/passes/cpp_pass.py b/python/paddle/distributed/passes/cpp_pass.py index 9201682d89..3a791610a5 100755 --- a/python/paddle/distributed/passes/cpp_pass.py +++ b/python/paddle/distributed/passes/cpp_pass.py @@ -71,6 +71,19 @@ class FuseReluDepthwiseConvPass(CPPPassWrapper): return PassType.FUSION_OPT +@register_pass("fused_attention") +class FusedAttentionPass(CPPPassWrapper): + def __init__(self): + super().__init__() + + @property + def cpp_name(self): + return "fused_attention_pass" + + def _type(self): + return PassType.FUSION_OPT + + @register_pass("fuse_gemm_epilogue") class FuseGemmEpiloguePass(CPPPassWrapper): def __init__(self): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1723ae48d1..1dd97cdccb 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -76,6 +76,7 @@ if(NOT WITH_GPU) list(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op) list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api) + list(REMOVE_ITEM TEST_OPS test_fused_attention_pass) endif() list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py new file mode 100644 index 0000000000..ff2e2f7328 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.fluid.core as core +import paddle.nn.functional as F +from paddle.distributed.passes import PassManager, new_pass + +paddle.enable_static() + + +class MultiHeadAttention(paddle.nn.Layer): + def __init__( + self, + embed_dim, + num_heads, + add_residual=True, + pre_ln=True, + post_ln=False, + attn_dropout=True, + ): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = embed_dim + self.vdim = embed_dim + self.num_heads = num_heads + + self.add_residual = add_residual + self.pre_ln = pre_ln + self.post_ln = post_ln + self.attn_dropout = attn_dropout + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.norm1 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5) + self.norm2 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5) + + self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim) + self.out_proj = paddle.nn.Linear(embed_dim, embed_dim) + self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train") + + def forward(self, x, attn_mask=None): + residual = x + + if self.pre_ln: + # pre layer norm + x = self.norm1(x) + + # compute qkv + qkv = self.qkv_proj(x) + qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim]) + qkv = paddle.transpose(qkv, [0, 2, 1, 3]) + q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1) + + # compute core attention + product = paddle.matmul(x=q, y=k, transpose_y=True) + product = paddle.scale(product, scale=self.head_dim**-0.5) + if attn_mask is not None: + product = product + attn_mask + weights = F.softmax(product) + if self.attn_dropout: + weights = F.dropout( + weights, 0.1, training=self.training, mode="upscale_in_train" + ) + out = paddle.matmul(weights, v) + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + out = self.dropout(out) + if self.add_residual: + out = residual + out + + if self.post_ln: + # post layer norm + out = self.norm2(out) + + return out + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestFusedAttentionPass(unittest.TestCase): + def setUp(self): + self.add_residual = True + self.pre_ln = True + self.post_ln = True + self.attn_dropout = True + self.add_mask = True + + def test_pass(self): + batch_size = 2 + seq_len = 1024 + hidden_size = 768 + num_heads = 12 + + x_data = np.random.rand(batch_size, seq_len, hidden_size).astype( + 'float32' + ) + mask_data = np.random.rand( + batch_size, num_heads, seq_len, seq_len + ).astype('float32') + + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + + with paddle.static.program_guard(main_prog, startup_prog): + data = paddle.static.data( + name="x", + shape=[-1, seq_len, hidden_size], + dtype='float32', + ) + if self.add_mask: + attn_mask = paddle.static.data( + name="attn_mask", + shape=[-1, num_heads, seq_len, seq_len], + dtype='float32', + ) + else: + attn_mask = None + multi_head_attn = MultiHeadAttention( + hidden_size, + num_heads, + add_residual=self.add_residual, + pre_ln=self.pre_ln, + post_ln=self.post_ln, + attn_dropout=self.attn_dropout, + ) + out = multi_head_attn(data, attn_mask) + loss = paddle.mean(out) + + sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001) + sgd_optimizer.minimize(loss) + + pass_manager = PassManager([new_pass("fused_attention")]) + pass_manager.apply([main_prog], [startup_prog]) + + ops = main_prog.global_block().ops + assert ops[0].type == 'reduce_mean' + + +if __name__ == "__main__": + np.random.seed(0) + unittest.main() -- GitLab