Unverified commit b0ece266, authored by Yuang Liu, committed by GitHub

[Fuse attention pass] Forward pattern. (#49621)

Parent 13992de7
@@ -382,6 +382,7 @@ set(IR_PASS_DEPS
   graph_to_program_pass
   fix_op_run_order_pass
   fuse_gemm_epilogue_pass
+  fused_attention_pass
   delete_dropout_op_pass)
 if(WITH_CINN)
...
@@ -187,6 +187,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
 #endif
+#ifdef PADDLE_WITH_CUDA
+    AppendPassWithCheck(strategy_.fused_attention_, "fused_attention_pass");
+#endif
 #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
     AppendPassWithCheck(strategy_.fuse_gemm_epilogue_,
                         "fuse_gemm_epilogue_pass");
@@ -519,6 +523,9 @@ USE_PASS(fuse_all_reduce_op_pass);
 USE_PASS(runtime_context_cache_pass);
 USE_PASS(add_reader_dependency_pass);
 USE_PASS(delete_dropout_op_x_pass);
+#ifdef PADDLE_WITH_CUDA
+USE_PASS(fused_attention_pass);
+#endif
 #ifdef PADDLE_WITH_CINN
 USE_PASS(build_cinn_pass);
 #endif
...
@@ -129,6 +129,8 @@ struct BuildStrategy {
   bool sync_batch_norm_{false};
   // Fuse GEMM+Epilogue via cublasLt epilogue.
   bool fuse_gemm_epilogue_{false};
+  // Fused multi head attention
+  bool fused_attention_{false};
   // mkldnn_enabled_op_types specify the operator type list to
   // use MKLDNN acceleration. It is null in default, means
@@ -261,6 +263,7 @@ inline std::ostream &operator<<(std::ostream &os,
   os << "fuse_broadcast_ops_: " << strategy.fuse_broadcast_ops_ << std::endl;
   os << "sync_batch_norm_: " << strategy.sync_batch_norm_ << std::endl;
   os << "fuse_gemm_epilogue_: " << strategy.fuse_gemm_epilogue_ << std::endl;
+  os << "fused_attention_: " << strategy.fused_attention_ << std::endl;
   os << "mkldnn_enabled_op_types_: ";
   for (auto str : strategy.mkldnn_enabled_op_types_) {
     os << str << ", ";
...
@@ -124,6 +124,7 @@ message BuildStrategy {
   optional int32 reduce_strategy = 15 [ default = 0 ];
   optional bool fuse_gemm_epilogue = 16 [ default = false ];
   optional string debug_graphviz_path = 17;
+  optional bool fused_attention = 18 [ default = false ];
 }
 message ExecutionStrategy {
...
@@ -223,6 +223,10 @@ cc_library(
   fuse_gemm_epilogue_pass
   SRCS fuse_gemm_epilogue_pass.cc
   DEPS pass graph_pattern_detector)
+cc_library(
+  fused_attention_pass
+  SRCS fused_attention_pass.cc
+  DEPS pass graph_pattern_detector)
 cc_library(
   fuse_relu_depthwise_conv_pass
   SRCS fuse_relu_depthwise_conv_pass.cc
...
This diff is collapsed.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Declare the forward pattern for multi head attention.
// It can detect:
// 1. Pre layer norm, post layer norm or sandwich layer norm.
// 2. Whether an attn mask is added to the qk product before the softmax.
// 3. Whether attn dropout is applied.
// 4. Whether a residual is added to the out linear result.
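// The matched forward subgraph (see the node declarations below) is roughly:
//   (pre layer norm ->) fuse_qkv matmul -> ele_add -> reshape -> transpose
//   -> split -> q*k matmul -> scale -> (ele_add attn mask ->) softmax
//   -> (attn dropout ->) qk*v matmul -> transpose -> reshape
//   -> out linear matmul -> ele_add -> dropout (-> residual add -> post layer norm)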
struct FusedAttentionPattern : public PatternBase {
FusedAttentionPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
PDNode* operator()(PDNode* x,
bool pre_layer_norm, // do pre ln or not
bool post_layer_norm, // do post ln or not
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
// pre layer norm
PATTERN_DECL_NODE(pre_layer_norm_op);
PATTERN_DECL_NODE(pre_layer_norm_scale);
PATTERN_DECL_NODE(pre_layer_norm_bias);
PATTERN_DECL_NODE(pre_layer_norm_out);
PATTERN_DECL_NODE(pre_layer_norm_mean);
PATTERN_DECL_NODE(pre_layer_norm_variance);
// fuse qkv projection
PATTERN_DECL_NODE(fuse_qkv_matmul_op);
PATTERN_DECL_NODE(fuse_qkv_matmul_w);
PATTERN_DECL_NODE(fuse_qkv_matmul_out);
PATTERN_DECL_NODE(fuse_qkv_ele_add_op);
PATTERN_DECL_NODE(fuse_qkv_ele_add_bias);
PATTERN_DECL_NODE(fuse_qkv_ele_add_out);
PATTERN_DECL_NODE(fuse_qkv_reshape_op);
PATTERN_DECL_NODE(fuse_qkv_reshape_out);
PATTERN_DECL_NODE(fuse_qkv_reshape_x_shape);
PATTERN_DECL_NODE(fuse_qkv_transpose_op);
PATTERN_DECL_NODE(fuse_qkv_transpose_out);
PATTERN_DECL_NODE(fuse_qkv_transpose_x_shape);
PATTERN_DECL_NODE(fuse_qkv_split_op);
PATTERN_DECL_NODE(fuse_qkv_split_out_q); // q
PATTERN_DECL_NODE(fuse_qkv_split_out_k); // k
PATTERN_DECL_NODE(fuse_qkv_split_out_v); // v
// core attention
PATTERN_DECL_NODE(qk_matmul_op);
PATTERN_DECL_NODE(qk_matmul_out);
PATTERN_DECL_NODE(qk_scale_op);
PATTERN_DECL_NODE(qk_scale_out);
PATTERN_DECL_NODE(add_mask_ele_add_op);
PATTERN_DECL_NODE(add_mask_ele_add_mask);
PATTERN_DECL_NODE(add_mask_ele_add_out);
PATTERN_DECL_NODE(qk_softmax_op);
PATTERN_DECL_NODE(qk_softmax_out);
PATTERN_DECL_NODE(attn_dropout_op);
PATTERN_DECL_NODE(attn_dropout_out);
PATTERN_DECL_NODE(attn_dropout_mask);
PATTERN_DECL_NODE(qkv_matmul_op);
PATTERN_DECL_NODE(qkv_matmul_out);
PATTERN_DECL_NODE(qkv_transpose_op);
PATTERN_DECL_NODE(qkv_transpose_out);
PATTERN_DECL_NODE(qkv_transpose_x_shape);
PATTERN_DECL_NODE(qkv_reshape_op);
PATTERN_DECL_NODE(qkv_reshape_out);
PATTERN_DECL_NODE(qkv_reshape_x_shape);
// out linear
PATTERN_DECL_NODE(out_linear_matmul_op);
PATTERN_DECL_NODE(out_linear_matmul_w);
PATTERN_DECL_NODE(out_linear_matmul_out);
PATTERN_DECL_NODE(out_linear_ele_add_op);
PATTERN_DECL_NODE(out_linear_ele_add_bias);
PATTERN_DECL_NODE(out_linear_ele_add_out);
PATTERN_DECL_NODE(out_linear_dropout_op);
PATTERN_DECL_NODE(out_linear_dropout_out);
PATTERN_DECL_NODE(out_linear_dropout_mask);
// residual
PATTERN_DECL_NODE(residual_ele_add_op);
PATTERN_DECL_NODE(residual_ele_add_out);
// post layer norm
PATTERN_DECL_NODE(post_layer_norm_op);
PATTERN_DECL_NODE(post_layer_norm_scale);
PATTERN_DECL_NODE(post_layer_norm_bias);
PATTERN_DECL_NODE(post_layer_norm_out);
PATTERN_DECL_NODE(post_layer_norm_mean);
PATTERN_DECL_NODE(post_layer_norm_variance);
};
// Declare the grad pattern for multi head attention
struct FusedAttentionGradPattern : public PatternBase {
FusedAttentionGradPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
PDNode* operator()(PDNode* x,
bool pre_layer_norm, // pre ln
bool post_layer_norm, // post ln
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
// TODO(Yuang Liu): add backward pattern
};
} // namespace patterns
class FusedAttentionsPass : public FusePassBase {
public:
virtual ~FusedAttentionsPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fused_attention_pass"};
private:
// The name rule for the helper function.
// The function name is built from at most six parts, in order:
// 1. Do pre layer norm? [Pre]
// 2. Add mask in the core attention part? [Mask]
// 3. Do dropout in the core attention part? [Drop]
// 4. Add residual? [Res]
// 5. Do post layer norm? [Post]
// 6. Forward or Backward? [Fwd/Bwd]
// For parts 1-5, the abbreviation appears in the function name only when the
// corresponding option is enabled; the Fwd/Bwd suffix is always present.
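// For example, PreMaskDropResPostFwd handles the forward graph with pre layer
// norm, attn mask, attn dropout, residual add and post layer norm all enabled.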
ir::Graph* PreMaskDropResPostFwd(Graph* graph) const;
ir::Graph* PreMaskDropResPostBwd(Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -714,6 +714,32 @@ void BindParallelExecutor(pybind11::module &m) {  // NOLINT
                 build_strategy = static.BuildStrategy()
                 build_strategy.fuse_gemm_epilogue = True
           )DOC")
+      .def_property(
+          "fused_attention",
+          [](const BuildStrategy &self) { return self.fused_attention_; },
+          [](BuildStrategy &self, bool b) {
+            PADDLE_ENFORCE_NE(self.IsFinalized(),
+                              true,
+                              platform::errors::PreconditionNotMet(
+                                  "BuildStrategy has been finalized, cannot be "
+                                  "configured again."));
+            self.fused_attention_ = b;
+          },
+          R"DOC((bool, optional): fused_attention indicates whether
+                to fuse the whole multi head attention part with one op,
+                which may make the execution faster. Default is False.
+
+                Examples:
+                    .. code-block:: python
+
+                        import paddle
+                        import paddle.static as static
+
+                        paddle.enable_static()
+
+                        build_strategy = static.BuildStrategy()
+                        build_strategy.fused_attention = True
+          )DOC")
       .def_property(
           "fuse_bn_act_ops",
           [](const BuildStrategy &self) { return self.fuse_bn_act_ops_; },
...
@@ -71,6 +71,19 @@ class FuseReluDepthwiseConvPass(CPPPassWrapper):
         return PassType.FUSION_OPT
 
 
+@register_pass("fused_attention")
+class FusedAttentionPass(CPPPassWrapper):
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def cpp_name(self):
+        return "fused_attention_pass"
+
+    def _type(self):
+        return PassType.FUSION_OPT
+
+
 @register_pass("fuse_gemm_epilogue")
 class FuseGemmEpiloguePass(CPPPassWrapper):
     def __init__(self):
...
@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
   list(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
   list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
   list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
+  list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
 endif()
 list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.nn.functional as F
from paddle.distributed.passes import PassManager, new_pass
paddle.enable_static()
class MultiHeadAttention(paddle.nn.Layer):
def __init__(
self,
embed_dim,
num_heads,
add_residual=True,
pre_ln=True,
post_ln=False,
attn_dropout=True,
):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = embed_dim
self.vdim = embed_dim
self.num_heads = num_heads
self.add_residual = add_residual
self.pre_ln = pre_ln
self.post_ln = post_ln
self.attn_dropout = attn_dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.norm1 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
self.norm2 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5)
self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = paddle.nn.Linear(embed_dim, embed_dim)
self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train")
def forward(self, x, attn_mask=None):
residual = x
if self.pre_ln:
# pre layer norm
x = self.norm1(x)
# compute qkv
qkv = self.qkv_proj(x)
qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim])
qkv = paddle.transpose(qkv, [0, 2, 1, 3])
q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1)
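# q, k and v each now have shape [batch_size, num_heads, seq_len, head_dim]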
# compute core attention
product = paddle.matmul(x=q, y=k, transpose_y=True)
product = paddle.scale(product, scale=self.head_dim**-0.5)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.attn_dropout:
weights = F.dropout(
weights, 0.1, training=self.training, mode="upscale_in_train"
)
out = paddle.matmul(weights, v)
out = paddle.transpose(out, perm=[0, 2, 1, 3])
out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
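# heads are merged back: out has shape [batch_size, seq_len, num_heads * head_dim]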
# project to output
out = self.out_proj(out)
out = self.dropout(out)
if self.add_residual:
out = residual + out
if self.post_ln:
# post layer norm
out = self.norm2(out)
return out
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedAttentionPass(unittest.TestCase):
def setUp(self):
self.add_residual = True
self.pre_ln = True
self.post_ln = True
self.attn_dropout = True
self.add_mask = True
def test_pass(self):
batch_size = 2
seq_len = 1024
hidden_size = 768
num_heads = 12
x_data = np.random.rand(batch_size, seq_len, hidden_size).astype(
'float32'
)
mask_data = np.random.rand(
batch_size, num_heads, seq_len, seq_len
).astype('float32')
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="x",
shape=[-1, seq_len, hidden_size],
dtype='float32',
)
if self.add_mask:
attn_mask = paddle.static.data(
name="attn_mask",
shape=[-1, num_heads, seq_len, seq_len],
dtype='float32',
)
else:
attn_mask = None
multi_head_attn = MultiHeadAttention(
hidden_size,
num_heads,
add_residual=self.add_residual,
pre_ln=self.pre_ln,
post_ln=self.post_ln,
attn_dropout=self.attn_dropout,
)
out = multi_head_attn(data, attn_mask)
loss = paddle.mean(out)
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(loss)
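# minimize() appends the backward ops, so the pass below runs over the full forward + backward program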
pass_manager = PassManager([new_pass("fused_attention")])
pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops
assert ops[0].type == 'reduce_mean'
if __name__ == "__main__":
np.random.seed(0)
unittest.main()