Unverified · Commit b0ece266 authored by Yuang Liu, committed by GitHub

[Fuse attention pass] Forward pattern. (#49621)

Parent 13992de7
@@ -382,6 +382,7 @@ set(IR_PASS_DEPS
graph_to_program_pass
fix_op_run_order_pass
fuse_gemm_epilogue_pass
fused_attention_pass
delete_dropout_op_pass)
if(WITH_CINN)
...
@@ -187,6 +187,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
#endif
#ifdef PADDLE_WITH_CUDA
AppendPassWithCheck(strategy_.fused_attention_, "fused_attention_pass");
#endif
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
AppendPassWithCheck(strategy_.fuse_gemm_epilogue_,
                    "fuse_gemm_epilogue_pass");
@@ -519,6 +523,9 @@ USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(add_reader_dependency_pass);
USE_PASS(delete_dropout_op_x_pass);
#ifdef PADDLE_WITH_CUDA
USE_PASS(fused_attention_pass);
#endif
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass);
#endif
...
@@ -129,6 +129,8 @@ struct BuildStrategy {
bool sync_batch_norm_{false};
// Fuse GEMM+Epilogue via cublasLt epilogue.
bool fuse_gemm_epilogue_{false};
// Fused multi head attention
bool fused_attention_{false};
// mkldnn_enabled_op_types specify the operator type list to
// use MKLDNN acceleration. It is null in default, means
@@ -261,6 +263,7 @@ inline std::ostream &operator<<(std::ostream &os,
os << "fuse_broadcast_ops_: " << strategy.fuse_broadcast_ops_ << std::endl;
os << "sync_batch_norm_: " << strategy.sync_batch_norm_ << std::endl;
os << "fuse_gemm_epilogue_: " << strategy.fuse_gemm_epilogue_ << std::endl;
os << "fused_attention_: " << strategy.fused_attention_ << std::endl;
os << "mkldnn_enabled_op_types_: "; os << "mkldnn_enabled_op_types_: ";
for (auto str : strategy.mkldnn_enabled_op_types_) { for (auto str : strategy.mkldnn_enabled_op_types_) {
os << str << ", "; os << str << ", ";
......
@@ -124,6 +124,7 @@ message BuildStrategy {
optional int32 reduce_strategy = 15 [ default = 0 ];
optional bool fuse_gemm_epilogue = 16 [ default = false ];
optional string debug_graphviz_path = 17;
optional bool fused_attention = 18 [ default = false ];
}
message ExecutionStrategy {
...
@@ -223,6 +223,10 @@ cc_library(
fuse_gemm_epilogue_pass
SRCS fuse_gemm_epilogue_pass.cc
DEPS pass graph_pattern_detector)
cc_library(
fused_attention_pass
SRCS fused_attention_pass.cc
DEPS pass graph_pattern_detector)
cc_library(
fuse_relu_depthwise_conv_pass
SRCS fuse_relu_depthwise_conv_pass.cc
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_attention_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
PDNode* FusedAttentionPattern::operator()(PDNode* x,
bool pre_layer_norm,
bool post_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual) {
// pre layer norm pattern
PDNode* pre_layer_norm_out_node{nullptr};
if (pre_layer_norm) {
auto* pre_layer_norm_node =
pattern->NewNode(pre_layer_norm_op_repr())->assert_is_op("layer_norm");
auto* pre_layer_norm_scale_node =
pattern->NewNode(pre_layer_norm_scale_repr())
->assert_is_op_input("layer_norm", "Scale");
auto* pre_layer_norm_bias_node =
pattern->NewNode(pre_layer_norm_bias_repr())
->assert_is_op_input("layer_norm", "Bias");
pre_layer_norm_out_node = pattern->NewNode(pre_layer_norm_out_repr())
->assert_is_op_output("layer_norm", "Y");
auto* pre_layer_norm_mean_node =
pattern->NewNode(pre_layer_norm_mean_repr())
->assert_is_op_output("layer_norm", "Mean");
auto* pre_layer_norm_variance_node =
pattern->NewNode(pre_layer_norm_variance_repr())
->assert_is_op_output("layer_norm", "Variance");
pre_layer_norm_node
->LinksFrom({x, pre_layer_norm_scale_node, pre_layer_norm_bias_node})
.LinksTo({pre_layer_norm_out_node,
pre_layer_norm_mean_node,
pre_layer_norm_variance_node});
}
// fuse qkv pattern
auto* fuse_qkv_matmul_node =
pattern->NewNode(fuse_qkv_matmul_op_repr())->assert_is_op("matmul_v2");
auto* fuse_qkv_matmul_w_node = pattern->NewNode(fuse_qkv_matmul_w_repr())
->assert_is_op_input("matmul_v2", "Y");
auto* fuse_qkv_matmul_out_node = pattern->NewNode(fuse_qkv_matmul_out_repr())
->assert_is_op_output("matmul_v2");
if (pre_layer_norm) {
pre_layer_norm_out_node->assert_is_op_input("matmul_v2", "X");
fuse_qkv_matmul_node
->LinksFrom({pre_layer_norm_out_node, fuse_qkv_matmul_w_node})
.LinksTo({fuse_qkv_matmul_out_node});
} else {
fuse_qkv_matmul_node->LinksFrom({x, fuse_qkv_matmul_w_node})
.LinksTo({fuse_qkv_matmul_out_node});
}
auto* fuse_qkv_ele_add_node = pattern->NewNode(fuse_qkv_ele_add_op_repr())
->assert_is_op("elementwise_add");
auto* fuse_qkv_ele_add_bias_node =
pattern->NewNode(fuse_qkv_ele_add_bias_repr())
->assert_is_op_input("elementwise_add", "Y");
auto* fuse_qkv_ele_add_out_node =
pattern->NewNode(fuse_qkv_ele_add_out_repr())
->assert_is_op_output("elementwise_add");
fuse_qkv_matmul_out_node->assert_is_op_input("elementwise_add", "X");
fuse_qkv_ele_add_node
->LinksFrom({fuse_qkv_matmul_out_node, fuse_qkv_ele_add_bias_node})
.LinksTo({fuse_qkv_ele_add_out_node});
auto* fuse_qkv_reshape_node =
pattern->NewNode(fuse_qkv_reshape_op_repr())->assert_is_op("reshape2");
auto* fuse_qkv_reshape_x_shape_node =
pattern->NewNode(fuse_qkv_reshape_x_shape_repr())
->assert_is_op_output("reshape2", "XShape");
auto* fuse_qkv_reshape_out_node =
pattern->NewNode(fuse_qkv_reshape_out_repr())
->assert_is_op_output("reshape2");
fuse_qkv_ele_add_out_node->assert_is_op_input("reshape2", "X");
fuse_qkv_reshape_node->LinksFrom({fuse_qkv_ele_add_out_node})
.LinksTo({fuse_qkv_reshape_x_shape_node, fuse_qkv_reshape_out_node});
auto* fuse_qkv_transpose_node = pattern->NewNode(fuse_qkv_transpose_op_repr())
->assert_is_op("transpose2");
auto* fuse_qkv_transpose_x_shape_node =
pattern->NewNode(fuse_qkv_transpose_x_shape_repr())
->assert_is_op_output("transpose2", "XShape");
auto* fuse_qkv_transpose_out_node =
pattern->NewNode(fuse_qkv_transpose_out_repr())
->assert_is_op_output("transpose2");
fuse_qkv_reshape_out_node->assert_is_op_input("transpose2", "X");
fuse_qkv_transpose_node->LinksFrom({fuse_qkv_reshape_out_node})
.LinksTo({fuse_qkv_transpose_x_shape_node, fuse_qkv_transpose_out_node});
auto* fuse_qkv_split_node =
pattern->NewNode(fuse_qkv_split_op_repr())->assert_is_op("split");
auto* fuse_qkv_split_out_q_node =
pattern->NewNode(fuse_qkv_split_out_q_repr())
->assert_is_op_output("split");
auto* fuse_qkv_split_out_k_node =
pattern->NewNode(fuse_qkv_split_out_k_repr())
->assert_is_op_output("split");
auto* fuse_qkv_split_out_v_node =
pattern->NewNode(fuse_qkv_split_out_v_repr())
->assert_is_op_output("split");
fuse_qkv_transpose_out_node->assert_is_op_input("split", "X");
fuse_qkv_split_node->LinksFrom({fuse_qkv_transpose_out_node})
.LinksTo({fuse_qkv_split_out_q_node,
fuse_qkv_split_out_k_node,
fuse_qkv_split_out_v_node});
// core attention pattern
auto* qk_matmul_node =
pattern->NewNode(qk_matmul_op_repr())->assert_is_op("matmul_v2");
auto* qk_matmul_out_node =
pattern->NewNode(qk_matmul_out_repr())->assert_is_op_output("matmul_v2");
fuse_qkv_split_out_q_node->assert_is_op_input("matmul_v2", "X");
fuse_qkv_split_out_k_node->assert_is_op_input("matmul_v2", "Y");
qk_matmul_node
->LinksFrom({fuse_qkv_split_out_q_node, fuse_qkv_split_out_k_node})
.LinksTo({qk_matmul_out_node});
auto* qk_scale_node =
pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
auto* qk_scale_out_node =
pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
qk_matmul_out_node->assert_is_op_input("scale", "X");
qk_scale_node->LinksFrom({qk_matmul_out_node}).LinksTo({qk_scale_out_node});
PDNode* add_mask_ele_add_out_node{nullptr};
if (has_attn_mask) {
auto* add_mask_ele_add_node = pattern->NewNode(add_mask_ele_add_op_repr())
->assert_is_op("elementwise_add");
auto* add_mask_ele_add_mask_node =
pattern->NewNode(add_mask_ele_add_mask_repr())
->assert_is_op_input("elementwise_add", "Y");
add_mask_ele_add_out_node = pattern->NewNode(add_mask_ele_add_out_repr())
->assert_is_op_output("elementwise_add");
qk_scale_out_node->assert_is_op_input("elementwise_add", "X");
add_mask_ele_add_node
->LinksFrom({qk_scale_out_node, add_mask_ele_add_mask_node})
.LinksTo({add_mask_ele_add_out_node});
}
auto* qk_softmax_node =
pattern->NewNode(qk_softmax_op_repr())->assert_is_op("softmax");
auto* qk_softmax_out_node =
pattern->NewNode(qk_softmax_out_repr())->assert_is_op_output("softmax");
if (has_attn_mask) {
add_mask_ele_add_out_node->assert_is_op_input("softmax", "X");
qk_softmax_node->LinksFrom({add_mask_ele_add_out_node})
.LinksTo({qk_softmax_out_node});
} else {
qk_scale_out_node->assert_is_op_input("softmax", "X");
qk_softmax_node->LinksFrom({qk_scale_out_node})
.LinksTo({qk_softmax_out_node});
}
PDNode* attn_dropout_out_node{nullptr};
if (do_dropout) {
auto* attn_dropout_node =
pattern->NewNode(attn_dropout_op_repr())->assert_is_op("dropout");
auto* attn_dropout_mask_node = pattern->NewNode(attn_dropout_mask_repr())
->assert_is_op_output("dropout", "Mask");
attn_dropout_out_node = pattern->NewNode(attn_dropout_out_repr())
->assert_is_op_output("dropout");
qk_softmax_out_node->assert_is_op_input("dropout", "X");
attn_dropout_node->LinksFrom({qk_softmax_out_node})
.LinksTo({attn_dropout_mask_node, attn_dropout_out_node});
}
auto* qkv_matmul_node =
pattern->NewNode(qkv_matmul_op_repr())->assert_is_op("matmul_v2");
auto* qkv_matmul_out_node =
pattern->NewNode(qkv_matmul_out_repr())->assert_is_op_output("matmul_v2");
fuse_qkv_split_out_v_node->assert_is_op_input("matmul_v2", "Y");
if (do_dropout) {
attn_dropout_out_node->assert_is_op_input("matmul_v2", "X");
qkv_matmul_node
->LinksFrom({attn_dropout_out_node, fuse_qkv_split_out_v_node})
.LinksTo({qkv_matmul_out_node});
} else {
qk_softmax_out_node->assert_is_op_input("matmul_v2", "X");
qkv_matmul_node->LinksFrom({qk_softmax_out_node, fuse_qkv_split_out_v_node})
.LinksTo({qkv_matmul_out_node});
}
auto* qkv_transpose_node =
pattern->NewNode(qkv_transpose_op_repr())->assert_is_op("transpose2");
auto* qkv_transpose_x_shape_node =
pattern->NewNode(qkv_transpose_x_shape_repr())
->assert_is_op_output("transpose2", "XShape");
auto* qkv_transpose_out_node = pattern->NewNode(qkv_transpose_out_repr())
->assert_is_op_output("transpose2");
qkv_matmul_out_node->assert_is_op_input("transpose2", "X");
qkv_transpose_node->LinksFrom({qkv_matmul_out_node})
.LinksTo({qkv_transpose_x_shape_node, qkv_transpose_out_node});
auto* qkv_reshape_node =
pattern->NewNode(qkv_reshape_op_repr())->assert_is_op("reshape2");
auto* qkv_reshape_x_shape_node =
pattern->NewNode(qkv_reshape_x_shape_repr())
->assert_is_op_output("reshape2", "XShape");
auto* qkv_reshape_out_node =
pattern->NewNode(qkv_reshape_out_repr())->assert_is_op_output("reshape2");
qkv_transpose_out_node->assert_is_op_input("reshape2", "X");
qkv_reshape_node->LinksFrom({qkv_transpose_out_node})
.LinksTo({qkv_reshape_x_shape_node, qkv_reshape_out_node});
// out linear pattern
auto* out_linear_matmul_node =
pattern->NewNode(out_linear_matmul_op_repr())->assert_is_op("matmul_v2");
auto* out_linear_matmul_w_node = pattern->NewNode(out_linear_matmul_w_repr())
->assert_is_op_input("matmul_v2", "Y");
auto* out_linear_matmul_out_node =
pattern->NewNode(out_linear_matmul_out_repr())
->assert_is_op_output("matmul_v2");
qkv_reshape_out_node->assert_is_op_input("matmul_v2", "X");
out_linear_matmul_node
->LinksFrom({qkv_reshape_out_node, out_linear_matmul_w_node})
.LinksTo({out_linear_matmul_out_node});
auto* out_linear_ele_add_node = pattern->NewNode(out_linear_ele_add_op_repr())
->assert_is_op("elementwise_add");
auto* out_linear_ele_add_bias_node =
pattern->NewNode(out_linear_ele_add_bias_repr())
->assert_is_op_input("elementwise_add", "Y");
auto* out_linear_ele_add_out_node =
pattern->NewNode(out_linear_ele_add_out_repr())
->assert_is_op_output("elementwise_add");
out_linear_matmul_out_node->assert_is_op_input("elementwise_add", "X");
out_linear_ele_add_node
->LinksFrom({out_linear_matmul_out_node, out_linear_ele_add_bias_node})
.LinksTo({out_linear_ele_add_out_node});
auto* out_linear_dropout_node =
pattern->NewNode(out_linear_dropout_op_repr())->assert_is_op("dropout");
auto* out_linear_dropout_mask_node =
pattern->NewNode(out_linear_dropout_mask_repr())
->assert_is_op_output("dropout", "Mask");
auto* out_linear_dropout_out_node =
pattern->NewNode(out_linear_dropout_out_repr())
->assert_is_op_output("dropout");
out_linear_ele_add_out_node->assert_is_op_input("dropout", "X");
out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node})
.LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node});
if (!add_residual && !post_layer_norm) {
return out_linear_dropout_out_node;
}
// add residual
PDNode* residual_ele_add_out_node{nullptr};
if (add_residual) {
// this kind of pattern only supports `residual + dropout_out`, since we have
// to fix X and Y
auto* residual_ele_add_node = pattern->NewNode(residual_ele_add_op_repr())
->assert_is_op("elementwise_add");
residual_ele_add_out_node = pattern->NewNode(residual_ele_add_out_repr())
->assert_is_op_output("elementwise_add");
out_linear_dropout_out_node->assert_is_op_input("elementwise_add", "Y");
residual_ele_add_node->LinksFrom({x, out_linear_dropout_out_node})
.LinksTo({residual_ele_add_out_node});
if (!post_layer_norm) {
return residual_ele_add_out_node;
}
}
// post layer norm
auto* post_layer_norm_node =
pattern->NewNode(post_layer_norm_op_repr())->assert_is_op("layer_norm");
auto* post_layer_norm_scale_node =
pattern->NewNode(post_layer_norm_scale_repr())
->assert_is_op_input("layer_norm", "Scale");
auto* post_layer_norm_bias_node =
pattern->NewNode(post_layer_norm_bias_repr())
->assert_is_op_input("layer_norm", "Bias");
auto* post_layer_norm_out_node = pattern->NewNode(post_layer_norm_out_repr())
->assert_is_op_output("layer_norm", "Y");
auto* post_layer_norm_mean_node =
pattern->NewNode(post_layer_norm_mean_repr())
->assert_is_op_output("layer_norm", "Mean");
auto* post_layer_norm_variance_node =
pattern->NewNode(post_layer_norm_variance_repr())
->assert_is_op_output("layer_norm", "Variance");
if (add_residual) {
residual_ele_add_out_node->assert_is_op_input("layer_norm", "X");
post_layer_norm_node
->LinksFrom({residual_ele_add_out_node,
post_layer_norm_scale_node,
post_layer_norm_bias_node})
.LinksTo({post_layer_norm_out_node,
post_layer_norm_mean_node,
post_layer_norm_variance_node});
} else {
out_linear_dropout_out_node->assert_is_op_input("layer_norm", "X");
post_layer_norm_node
->LinksFrom({out_linear_dropout_out_node,
post_layer_norm_scale_node,
post_layer_norm_bias_node})
.LinksTo({post_layer_norm_out_node,
post_layer_norm_mean_node,
post_layer_norm_variance_node});
}
return post_layer_norm_out_node;
}
PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
bool pre_layer_norm,
bool post_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual) {
// TODO(Yuang Liu): finish the backward pattern
return nullptr;
}
} // namespace patterns
void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
graph = PreMaskDropResPostFwd(graph);
graph = PreMaskDropResPostBwd(graph);
}
ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x"))
->AsInput()
->assert_is_op_input("layer_norm", "X");
patterns::FusedAttentionPattern fused_attention_pattern(
gpd.mutable_pattern(), "fused_attention_pattern");
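// Only the fully-featured forward variant is wired up so far: pre layer norm,
// attention mask, attention dropout, residual add and post layer norm all
// enabled. Other flag combinations declared by the pattern are not yet matched
// by a handler.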
fused_attention_pattern(x,
/* pre_layer_norm */ true,
/* post_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true);
int found_fused_attention = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(3) << "handle FusedMultiHeadAttention pass's fusion";
GET_IR_NODE_FROM_SUBGRAPH(
pre_layer_norm_op_node, pre_layer_norm_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_matmul_op_node, fuse_qkv_matmul_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_ele_add_op_node, fuse_qkv_ele_add_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_reshape_op_node, fuse_qkv_reshape_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_transpose_op_node,
fuse_qkv_transpose_op,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_split_op_node, fuse_qkv_split_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qk_matmul_op_node, qk_matmul_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qk_scale_op_node, qk_scale_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
add_mask_ele_add_op_node, add_mask_ele_add_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qk_softmax_op_node, qk_softmax_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attn_dropout_op_node, attn_dropout_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qkv_matmul_op_node, qkv_matmul_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qkv_transpose_op_node, qkv_transpose_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qkv_reshape_op_node, qkv_reshape_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_op_node,
out_linear_matmul_op,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_op_node,
out_linear_ele_add_op,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_op_node,
out_linear_dropout_op,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
post_layer_norm_op_node, post_layer_norm_op, fused_attention_pattern);
// TODO(Yuang Liu): finish the handler
GraphSafeRemoveNodes(g,
{pre_layer_norm_op_node,
fuse_qkv_matmul_op_node,
fuse_qkv_ele_add_op_node,
fuse_qkv_reshape_op_node,
fuse_qkv_transpose_op_node,
fuse_qkv_split_op_node,
qk_matmul_op_node,
qk_scale_op_node,
add_mask_ele_add_op_node,
qk_softmax_op_node,
attn_dropout_op_node,
qkv_matmul_op_node,
qkv_transpose_op_node,
qkv_reshape_op_node,
out_linear_matmul_op_node,
out_linear_ele_add_op_node,
out_linear_dropout_op_node,
residual_ele_add_op_node,
post_layer_norm_op_node});
found_fused_attention++;
};
gpd(graph, handler);
AddStatis(found_fused_attention);
return graph;
}
ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
// TODO(Yuang Liu): finish the pass
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_attention_pass, paddle::framework::ir::FusedAttentionsPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Declare patterns for multi head attention.
// Can detect:
// 1. Pre layer norm, post layer norm or sandwich layer norm.
// 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not.
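// A rough sketch of the forward subgraph this pattern matches, read off the
// node definitions in fused_attention_pass.cc (bracketed stages are optional):
//   x -> [layer_norm] -> matmul_v2 -> elementwise_add -> reshape2 -> transpose2 -> split -> {Q, K, V}
//   Q, K -> matmul_v2 -> scale -> [elementwise_add(attn_mask)] -> softmax -> [dropout]
//   ..., V -> matmul_v2 -> transpose2 -> reshape2
//   -> matmul_v2 -> elementwise_add -> dropout -> [elementwise_add(residual)] -> [layer_norm]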
struct FusedAttentionPattern : public PatternBase {
FusedAttentionPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
PDNode* operator()(PDNode* x,
bool pre_layer_norm, // do pre ln or not
bool post_layer_norm, // do post ln or not
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
// pre layer norm
PATTERN_DECL_NODE(pre_layer_norm_op);
PATTERN_DECL_NODE(pre_layer_norm_scale);
PATTERN_DECL_NODE(pre_layer_norm_bias);
PATTERN_DECL_NODE(pre_layer_norm_out);
PATTERN_DECL_NODE(pre_layer_norm_mean);
PATTERN_DECL_NODE(pre_layer_norm_variance);
// fuse qkv projection
PATTERN_DECL_NODE(fuse_qkv_matmul_op);
PATTERN_DECL_NODE(fuse_qkv_matmul_w);
PATTERN_DECL_NODE(fuse_qkv_matmul_out);
PATTERN_DECL_NODE(fuse_qkv_ele_add_op);
PATTERN_DECL_NODE(fuse_qkv_ele_add_bias);
PATTERN_DECL_NODE(fuse_qkv_ele_add_out);
PATTERN_DECL_NODE(fuse_qkv_reshape_op);
PATTERN_DECL_NODE(fuse_qkv_reshape_out);
PATTERN_DECL_NODE(fuse_qkv_reshape_x_shape);
PATTERN_DECL_NODE(fuse_qkv_transpose_op);
PATTERN_DECL_NODE(fuse_qkv_transpose_out);
PATTERN_DECL_NODE(fuse_qkv_transpose_x_shape);
PATTERN_DECL_NODE(fuse_qkv_split_op);
PATTERN_DECL_NODE(fuse_qkv_split_out_q); // q
PATTERN_DECL_NODE(fuse_qkv_split_out_k); // k
PATTERN_DECL_NODE(fuse_qkv_split_out_v); // v
// core attention
PATTERN_DECL_NODE(qk_matmul_op);
PATTERN_DECL_NODE(qk_matmul_out);
PATTERN_DECL_NODE(qk_scale_op);
PATTERN_DECL_NODE(qk_scale_out);
PATTERN_DECL_NODE(add_mask_ele_add_op);
PATTERN_DECL_NODE(add_mask_ele_add_mask);
PATTERN_DECL_NODE(add_mask_ele_add_out);
PATTERN_DECL_NODE(qk_softmax_op);
PATTERN_DECL_NODE(qk_softmax_out);
PATTERN_DECL_NODE(attn_dropout_op);
PATTERN_DECL_NODE(attn_dropout_out);
PATTERN_DECL_NODE(attn_dropout_mask);
PATTERN_DECL_NODE(qkv_matmul_op);
PATTERN_DECL_NODE(qkv_matmul_out);
PATTERN_DECL_NODE(qkv_transpose_op);
PATTERN_DECL_NODE(qkv_transpose_out);
PATTERN_DECL_NODE(qkv_transpose_x_shape);
PATTERN_DECL_NODE(qkv_reshape_op);
PATTERN_DECL_NODE(qkv_reshape_out);
PATTERN_DECL_NODE(qkv_reshape_x_shape);
// out linear
PATTERN_DECL_NODE(out_linear_matmul_op);
PATTERN_DECL_NODE(out_linear_matmul_w);
PATTERN_DECL_NODE(out_linear_matmul_out);
PATTERN_DECL_NODE(out_linear_ele_add_op);
PATTERN_DECL_NODE(out_linear_ele_add_bias);
PATTERN_DECL_NODE(out_linear_ele_add_out);
PATTERN_DECL_NODE(out_linear_dropout_op);
PATTERN_DECL_NODE(out_linear_dropout_out);
PATTERN_DECL_NODE(out_linear_dropout_mask);
// residual
PATTERN_DECL_NODE(residual_ele_add_op);
PATTERN_DECL_NODE(residual_ele_add_out);
// post layer norm
PATTERN_DECL_NODE(post_layer_norm_op);
PATTERN_DECL_NODE(post_layer_norm_scale);
PATTERN_DECL_NODE(post_layer_norm_bias);
PATTERN_DECL_NODE(post_layer_norm_out);
PATTERN_DECL_NODE(post_layer_norm_mean);
PATTERN_DECL_NODE(post_layer_norm_variance);
};
// Declare the grad pattern for multi head attention
struct FusedAttentionGradPattern : public PatternBase {
FusedAttentionGradPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
PDNode* operator()(PDNode* x,
bool pre_layer_norm, // pre ln
bool post_layer_norm, // post ln
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
// TODO(Yuang Liu): add backward pattern
};
} // namespace patterns
class FusedAttentionsPass : public FusePassBase {
public:
virtual ~FusedAttentionsPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_attention_pass"};
private:
// The naming rule for the helper functions:
// a function name is assembled from up to five optional abbreviations plus a
// Fwd/Bwd suffix, in order:
// 1. Do pre layer norm? [Pre]
// 2. Add mask in the core attention part? [Mask]
// 3. Do dropout in the core attention part? [Drop]
// 4. Add residual? [Res]
// 5. Do post layer norm? [Post]
// 6. Forward or Backward? [Fwd/Bwd]
// If an option holds, its abbreviation appears in the function name;
// otherwise the abbreviation is omitted.
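// For example, PreMaskDropResPostFwd below handles the forward pattern with
// pre layer norm, attention mask, attention dropout, residual add and post
// layer norm all enabled.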
ir::Graph* PreMaskDropResPostFwd(Graph* graph) const;
ir::Graph* PreMaskDropResPostBwd(Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -714,6 +714,32 @@ void BindParallelExecutor(pybind11::module &m) {  // NOLINT
build_strategy = static.BuildStrategy()
build_strategy.fuse_gemm_epilogue = True
)DOC")
.def_property(
"fused_attention",
[](const BuildStrategy &self) { return self.fused_attention_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_NE(self.IsFinalized(),
true,
platform::errors::PreconditionNotMet(
"BuildStrategy has been finlaized, cannot be "
"configured again."));
self.fused_attention_ = b;
},
R"DOC((bool, optional): fused_attention indicate whether
to fuse the whole multi head attention part with one op,
it may make the execution faster. Default is False.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fused_attention = True
)DOC")
.def_property(
"fuse_bn_act_ops",
[](const BuildStrategy &self) { return self.fuse_bn_act_ops_; },
...
@@ -71,6 +71,19 @@ class FuseReluDepthwiseConvPass(CPPPassWrapper):
return PassType.FUSION_OPT
@register_pass("fused_attention")
class FusedAttentionPass(CPPPassWrapper):
def __init__(self):
super().__init__()
@property
def cpp_name(self):
return "fused_attention_pass"
def _type(self):
return PassType.FUSION_OPT
@register_pass("fuse_gemm_epilogue") @register_pass("fuse_gemm_epilogue")
class FuseGemmEpiloguePass(CPPPassWrapper): class FuseGemmEpiloguePass(CPPPassWrapper):
def __init__(self): def __init__(self):
......
@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
endif()
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.nn.functional as F
from paddle.distributed.passes import PassManager, new_pass
paddle.enable_static()
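# A plain static-graph multi head attention block. Its forward produces the op
# sequence the fused_attention_pass pattern looks for: layer_norm, matmul_v2,
# elementwise_add, reshape2, transpose2, split, matmul_v2, scale,
# elementwise_add (mask), softmax, dropout, matmul_v2, transpose2, reshape2,
# matmul_v2, elementwise_add, dropout, elementwise_add (residual), layer_norm.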
class MultiHeadAttention(paddle.nn.Layer):
def __init__(
self,
embed_dim,
num_heads,
add_residual=True,
pre_ln=True,
post_ln=False,
attn_dropout=True,
):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = embed_dim
self.vdim = embed_dim
self.num_heads = num_heads
self.add_residual = add_residual
self.pre_ln = pre_ln
self.post_ln = post_ln
self.attn_dropout = attn_dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.norm1 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
self.norm2 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = paddle.nn.Linear(embed_dim, embed_dim)
self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train")
def forward(self, x, attn_mask=None):
residual = x
if self.pre_ln:
# pre layer norm
x = self.norm1(x)
# compute qkv
qkv = self.qkv_proj(x)
qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim])
qkv = paddle.transpose(qkv, [0, 2, 1, 3])
q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1)
# compute core attention
product = paddle.matmul(x=q, y=k, transpose_y=True)
product = paddle.scale(product, scale=self.head_dim**-0.5)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.attn_dropout:
weights = F.dropout(
weights, 0.1, training=self.training, mode="upscale_in_train"
)
out = paddle.matmul(weights, v)
out = paddle.transpose(out, perm=[0, 2, 1, 3])
out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
out = self.dropout(out)
if self.add_residual:
out = residual + out
if self.post_ln:
# post layer norm
out = self.norm2(out)
return out
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedAttentionPass(unittest.TestCase):
def setUp(self):
self.add_residual = True
self.pre_ln = True
self.post_ln = True
self.attn_dropout = True
self.add_mask = True
def test_pass(self):
batch_size = 2
seq_len = 1024
hidden_size = 768
num_heads = 12
x_data = np.random.rand(batch_size, seq_len, hidden_size).astype(
'float32'
)
mask_data = np.random.rand(
batch_size, num_heads, seq_len, seq_len
).astype('float32')
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="x",
shape=[-1, seq_len, hidden_size],
dtype='float32',
)
if self.add_mask:
attn_mask = paddle.static.data(
name="attn_mask",
shape=[-1, num_heads, seq_len, seq_len],
dtype='float32',
)
else:
attn_mask = None
multi_head_attn = MultiHeadAttention(
hidden_size,
num_heads,
add_residual=self.add_residual,
pre_ln=self.pre_ln,
post_ln=self.post_ln,
attn_dropout=self.attn_dropout,
)
out = multi_head_attn(data, attn_mask)
loss = paddle.mean(out)
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(loss)
pass_manager = PassManager([new_pass("fused_attention")])
pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops
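# The forward-only pass removes the matched attention ops from the block (the
# fused op is not inserted yet, see the TODO in the handler), so the first op
# left is the reduce_mean that computes the loss.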
assert ops[0].type == 'reduce_mean'
if __name__ == "__main__":
np.random.seed(0)
unittest.main()