Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b0ece266
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b0ece266
编写于
1月 10, 2023
作者:
Y
Yuang Liu
提交者:
GitHub
1月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Fuse attention pass] Forward pattern. (#49621)
上级
13992de7
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
844 addition
and
0 deletion
+844
-0
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-0
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+7
-0
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+3
-0
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+4
-0
paddle/fluid/framework/ir/fused_attention_pass.cc
paddle/fluid/framework/ir/fused_attention_pass.cc
+448
-0
paddle/fluid/framework/ir/fused_attention_pass.h
paddle/fluid/framework/ir/fused_attention_pass.h
+176
-0
paddle/fluid/pybind/parallel_executor.cc
paddle/fluid/pybind/parallel_executor.cc
+26
-0
python/paddle/distributed/passes/cpp_pass.py
python/paddle/distributed/passes/cpp_pass.py
+13
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
...paddle/fluid/tests/unittests/test_fused_attention_pass.py
+164
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
b0ece266
...
@@ -382,6 +382,7 @@ set(IR_PASS_DEPS
...
@@ -382,6 +382,7 @@ set(IR_PASS_DEPS
graph_to_program_pass
graph_to_program_pass
fix_op_run_order_pass
fix_op_run_order_pass
fuse_gemm_epilogue_pass
fuse_gemm_epilogue_pass
fused_attention_pass
delete_dropout_op_pass
)
delete_dropout_op_pass
)
if
(
WITH_CINN
)
if
(
WITH_CINN
)
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
b0ece266
...
@@ -187,6 +187,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -187,6 +187,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPassWithCheck
(
strategy_
.
enable_auto_fusion_
,
"fusion_group_pass"
);
AppendPassWithCheck
(
strategy_
.
enable_auto_fusion_
,
"fusion_group_pass"
);
#endif
#endif
#ifdef PADDLE_WITH_CUDA
AppendPassWithCheck
(
strategy_
.
fused_attention_
,
"fused_attention_pass"
);
#endif
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
AppendPassWithCheck
(
strategy_
.
fuse_gemm_epilogue_
,
AppendPassWithCheck
(
strategy_
.
fuse_gemm_epilogue_
,
"fuse_gemm_epilogue_pass"
);
"fuse_gemm_epilogue_pass"
);
...
@@ -519,6 +523,9 @@ USE_PASS(fuse_all_reduce_op_pass);
...
@@ -519,6 +523,9 @@ USE_PASS(fuse_all_reduce_op_pass);
USE_PASS
(
runtime_context_cache_pass
);
USE_PASS
(
runtime_context_cache_pass
);
USE_PASS
(
add_reader_dependency_pass
);
USE_PASS
(
add_reader_dependency_pass
);
USE_PASS
(
delete_dropout_op_x_pass
);
USE_PASS
(
delete_dropout_op_x_pass
);
#ifdef PADDLE_WITH_CUDA
USE_PASS
(
fused_attention_pass
);
#endif
#ifdef PADDLE_WITH_CINN
#ifdef PADDLE_WITH_CINN
USE_PASS
(
build_cinn_pass
);
USE_PASS
(
build_cinn_pass
);
#endif
#endif
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
b0ece266
...
@@ -129,6 +129,8 @@ struct BuildStrategy {
...
@@ -129,6 +129,8 @@ struct BuildStrategy {
bool
sync_batch_norm_
{
false
};
bool
sync_batch_norm_
{
false
};
// Fuse GEMM+Epilogue via cublasLt epilogue.
// Fuse GEMM+Epilogue via cublasLt epilogue.
bool
fuse_gemm_epilogue_
{
false
};
bool
fuse_gemm_epilogue_
{
false
};
// Fused multi head attention
bool
fused_attention_
{
false
};
// mkldnn_enabled_op_types specify the operator type list to
// mkldnn_enabled_op_types specify the operator type list to
// use MKLDNN acceleration. It is null in default, means
// use MKLDNN acceleration. It is null in default, means
...
@@ -261,6 +263,7 @@ inline std::ostream &operator<<(std::ostream &os,
...
@@ -261,6 +263,7 @@ inline std::ostream &operator<<(std::ostream &os,
os
<<
"fuse_broadcast_ops_: "
<<
strategy
.
fuse_broadcast_ops_
<<
std
::
endl
;
os
<<
"fuse_broadcast_ops_: "
<<
strategy
.
fuse_broadcast_ops_
<<
std
::
endl
;
os
<<
"sync_batch_norm_: "
<<
strategy
.
sync_batch_norm_
<<
std
::
endl
;
os
<<
"sync_batch_norm_: "
<<
strategy
.
sync_batch_norm_
<<
std
::
endl
;
os
<<
"fuse_gemm_epilogue_: "
<<
strategy
.
fuse_gemm_epilogue_
<<
std
::
endl
;
os
<<
"fuse_gemm_epilogue_: "
<<
strategy
.
fuse_gemm_epilogue_
<<
std
::
endl
;
os
<<
"fused_attention_: "
<<
strategy
.
fused_attention_
<<
std
::
endl
;
os
<<
"mkldnn_enabled_op_types_: "
;
os
<<
"mkldnn_enabled_op_types_: "
;
for
(
auto
str
:
strategy
.
mkldnn_enabled_op_types_
)
{
for
(
auto
str
:
strategy
.
mkldnn_enabled_op_types_
)
{
os
<<
str
<<
", "
;
os
<<
str
<<
", "
;
...
...
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
b0ece266
...
@@ -124,6 +124,7 @@ message BuildStrategy {
...
@@ -124,6 +124,7 @@ message BuildStrategy {
optional
int32
reduce_strategy
=
15
[
default
=
0
];
optional
int32
reduce_strategy
=
15
[
default
=
0
];
optional
bool
fuse_gemm_epilogue
=
16
[
default
=
false
];
optional
bool
fuse_gemm_epilogue
=
16
[
default
=
false
];
optional
string
debug_graphviz_path
=
17
;
optional
string
debug_graphviz_path
=
17
;
optional
bool
fused_attention
=
18
[
default
=
false
];
}
}
message
ExecutionStrategy
{
message
ExecutionStrategy
{
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
b0ece266
...
@@ -223,6 +223,10 @@ cc_library(
...
@@ -223,6 +223,10 @@ cc_library(
fuse_gemm_epilogue_pass
fuse_gemm_epilogue_pass
SRCS fuse_gemm_epilogue_pass.cc
SRCS fuse_gemm_epilogue_pass.cc
DEPS pass graph_pattern_detector
)
DEPS pass graph_pattern_detector
)
cc_library
(
fused_attention_pass
SRCS fused_attention_pass.cc
DEPS pass graph_pattern_detector
)
cc_library
(
cc_library
(
fuse_relu_depthwise_conv_pass
fuse_relu_depthwise_conv_pass
SRCS fuse_relu_depthwise_conv_pass.cc
SRCS fuse_relu_depthwise_conv_pass.cc
...
...
paddle/fluid/framework/ir/fused_attention_pass.cc
0 → 100644
浏览文件 @
b0ece266
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_attention_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
PDNode
*
FusedAttentionPattern
::
operator
()(
PDNode
*
x
,
bool
pre_layer_norm
,
bool
post_layer_norm
,
bool
has_attn_mask
,
bool
do_dropout
,
bool
add_residual
)
{
// pre layer norm pattern
PDNode
*
pre_layer_norm_out_node
{
nullptr
};
if
(
pre_layer_norm
)
{
auto
*
pre_layer_norm_node
=
pattern
->
NewNode
(
pre_layer_norm_op_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
pre_layer_norm_scale_node
=
pattern
->
NewNode
(
pre_layer_norm_scale_repr
())
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
pre_layer_norm_bias_node
=
pattern
->
NewNode
(
pre_layer_norm_bias_repr
())
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
pre_layer_norm_out_node
=
pattern
->
NewNode
(
pre_layer_norm_out_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Y"
);
auto
*
pre_layer_norm_mean_node
=
pattern
->
NewNode
(
pre_layer_norm_mean_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
pre_layer_norm_variance_node
=
pattern
->
NewNode
(
pre_layer_norm_variance_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
pre_layer_norm_node
->
LinksFrom
({
x
,
pre_layer_norm_scale_node
,
pre_layer_norm_bias_node
})
.
LinksTo
({
pre_layer_norm_out_node
,
pre_layer_norm_mean_node
,
pre_layer_norm_variance_node
});
}
// fuse qkv pattern
auto
*
fuse_qkv_matmul_node
=
pattern
->
NewNode
(
fuse_qkv_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
fuse_qkv_matmul_w_node
=
pattern
->
NewNode
(
fuse_qkv_matmul_w_repr
())
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
fuse_qkv_matmul_out_node
=
pattern
->
NewNode
(
fuse_qkv_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
if
(
pre_layer_norm
)
{
pre_layer_norm_out_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
fuse_qkv_matmul_node
->
LinksFrom
({
pre_layer_norm_out_node
,
fuse_qkv_matmul_w_node
})
.
LinksTo
({
fuse_qkv_matmul_out_node
});
}
else
{
fuse_qkv_matmul_node
->
LinksFrom
({
x
,
fuse_qkv_matmul_w_node
})
.
LinksTo
({
fuse_qkv_matmul_out_node
});
}
auto
*
fuse_qkv_ele_add_node
=
pattern
->
NewNode
(
fuse_qkv_ele_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
fuse_qkv_ele_add_bias_node
=
pattern
->
NewNode
(
fuse_qkv_ele_add_bias_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
fuse_qkv_ele_add_out_node
=
pattern
->
NewNode
(
fuse_qkv_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
fuse_qkv_matmul_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
fuse_qkv_ele_add_node
->
LinksFrom
({
fuse_qkv_matmul_out_node
,
fuse_qkv_ele_add_bias_node
})
.
LinksTo
({
fuse_qkv_ele_add_out_node
});
auto
*
fuse_qkv_reshape_node
=
pattern
->
NewNode
(
fuse_qkv_reshape_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
fuse_qkv_reshape_x_shape_node
=
pattern
->
NewNode
(
fuse_qkv_reshape_x_shape_repr
())
->
assert_is_op_output
(
"reshape2"
,
"XShape"
);
auto
*
fuse_qkv_reshape_out_node
=
pattern
->
NewNode
(
fuse_qkv_reshape_out_repr
())
->
assert_is_op_output
(
"reshape2"
);
fuse_qkv_ele_add_out_node
->
assert_is_op_input
(
"reshape2"
,
"X"
);
fuse_qkv_reshape_node
->
LinksFrom
({
fuse_qkv_ele_add_out_node
})
.
LinksTo
({
fuse_qkv_reshape_x_shape_node
,
fuse_qkv_reshape_out_node
});
auto
*
fuse_qkv_transpose_node
=
pattern
->
NewNode
(
fuse_qkv_transpose_op_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
fuse_qkv_transpose_x_shape_node
=
pattern
->
NewNode
(
fuse_qkv_transpose_x_shape_repr
())
->
assert_is_op_output
(
"transpose2"
,
"XShape"
);
auto
*
fuse_qkv_transpose_out_node
=
pattern
->
NewNode
(
fuse_qkv_transpose_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
fuse_qkv_reshape_out_node
->
assert_is_op_input
(
"transpose2"
,
"X"
);
fuse_qkv_transpose_node
->
LinksFrom
({
fuse_qkv_reshape_out_node
})
.
LinksTo
({
fuse_qkv_transpose_x_shape_node
,
fuse_qkv_transpose_out_node
});
auto
*
fuse_qkv_split_node
=
pattern
->
NewNode
(
fuse_qkv_split_op_repr
())
->
assert_is_op
(
"split"
);
auto
*
fuse_qkv_split_out_q_node
=
pattern
->
NewNode
(
fuse_qkv_split_out_q_repr
())
->
assert_is_op_output
(
"split"
);
auto
*
fuse_qkv_split_out_k_node
=
pattern
->
NewNode
(
fuse_qkv_split_out_k_repr
())
->
assert_is_op_output
(
"split"
);
auto
*
fuse_qkv_split_out_v_node
=
pattern
->
NewNode
(
fuse_qkv_split_out_v_repr
())
->
assert_is_op_output
(
"split"
);
fuse_qkv_transpose_out_node
->
assert_is_op_input
(
"split"
,
"X"
);
fuse_qkv_split_node
->
LinksFrom
({
fuse_qkv_transpose_out_node
})
.
LinksTo
({
fuse_qkv_split_out_q_node
,
fuse_qkv_split_out_k_node
,
fuse_qkv_split_out_v_node
});
// core attention pattern
auto
*
qk_matmul_node
=
pattern
->
NewNode
(
qk_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
qk_matmul_out_node
=
pattern
->
NewNode
(
qk_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
fuse_qkv_split_out_q_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
fuse_qkv_split_out_k_node
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
qk_matmul_node
->
LinksFrom
({
fuse_qkv_split_out_q_node
,
fuse_qkv_split_out_k_node
})
.
LinksTo
({
qk_matmul_out_node
});
auto
*
qk_scale_node
=
pattern
->
NewNode
(
qk_scale_op_repr
())
->
assert_is_op
(
"scale"
);
auto
*
qk_scale_out_node
=
pattern
->
NewNode
(
qk_scale_out_repr
())
->
assert_is_op_output
(
"scale"
);
qk_matmul_out_node
->
assert_is_op_input
(
"scale"
,
"X"
);
qk_scale_node
->
LinksFrom
({
qk_matmul_out_node
}).
LinksTo
({
qk_scale_out_node
});
PDNode
*
add_mask_ele_add_out_node
{
nullptr
};
if
(
has_attn_mask
)
{
auto
*
add_mask_ele_add_node
=
pattern
->
NewNode
(
add_mask_ele_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
add_mask_ele_add_mask_node
=
pattern
->
NewNode
(
add_mask_ele_add_mask_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
add_mask_ele_add_out_node
=
pattern
->
NewNode
(
add_mask_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
qk_scale_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
add_mask_ele_add_node
->
LinksFrom
({
qk_scale_out_node
,
add_mask_ele_add_mask_node
})
.
LinksTo
({
add_mask_ele_add_out_node
});
}
auto
*
qk_softmax_node
=
pattern
->
NewNode
(
qk_softmax_op_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
qk_softmax_out_node
=
pattern
->
NewNode
(
qk_softmax_out_repr
())
->
assert_is_op_output
(
"softmax"
);
if
(
has_attn_mask
)
{
add_mask_ele_add_out_node
->
assert_is_op_input
(
"softmax"
,
"X"
);
qk_softmax_node
->
LinksFrom
({
add_mask_ele_add_out_node
})
.
LinksTo
({
qk_softmax_out_node
});
}
else
{
qk_scale_out_node
->
assert_is_op_input
(
"softmax"
,
"X"
);
qk_softmax_node
->
LinksFrom
({
qk_scale_out_node
})
.
LinksTo
({
qk_softmax_out_node
});
}
PDNode
*
attn_dropout_out_node
{
nullptr
};
if
(
do_dropout
)
{
auto
*
attn_dropout_node
=
pattern
->
NewNode
(
attn_dropout_op_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
attn_dropout_mask_node
=
pattern
->
NewNode
(
attn_dropout_mask_repr
())
->
assert_is_op_output
(
"dropout"
,
"Mask"
);
attn_dropout_out_node
=
pattern
->
NewNode
(
attn_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
);
qk_softmax_out_node
->
assert_is_op_input
(
"dropout"
,
"X"
);
attn_dropout_node
->
LinksFrom
({
qk_softmax_out_node
})
.
LinksTo
({
attn_dropout_mask_node
,
attn_dropout_out_node
});
}
auto
*
qkv_matmul_node
=
pattern
->
NewNode
(
qkv_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
qkv_matmul_out_node
=
pattern
->
NewNode
(
qkv_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
fuse_qkv_split_out_v_node
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
if
(
do_dropout
)
{
attn_dropout_out_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
qkv_matmul_node
->
LinksFrom
({
attn_dropout_out_node
,
fuse_qkv_split_out_v_node
})
.
LinksTo
({
qkv_matmul_out_node
});
}
else
{
qk_softmax_out_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
qkv_matmul_node
->
LinksFrom
({
qk_softmax_out_node
,
fuse_qkv_split_out_v_node
})
.
LinksTo
({
qkv_matmul_out_node
});
}
auto
*
qkv_transpose_node
=
pattern
->
NewNode
(
qkv_transpose_op_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
qkv_transpose_x_shape_node
=
pattern
->
NewNode
(
qkv_transpose_x_shape_repr
())
->
assert_is_op_output
(
"transpose2"
,
"XShape"
);
auto
*
qkv_transpose_out_node
=
pattern
->
NewNode
(
qkv_transpose_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
qkv_matmul_out_node
->
assert_is_op_input
(
"transpose2"
,
"X"
);
qkv_transpose_node
->
LinksFrom
({
qkv_matmul_out_node
})
.
LinksTo
({
qkv_transpose_x_shape_node
,
qkv_transpose_out_node
});
auto
*
qkv_reshape_node
=
pattern
->
NewNode
(
qkv_reshape_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
qkv_reshape_x_shape_node
=
pattern
->
NewNode
(
qkv_reshape_x_shape_repr
())
->
assert_is_op_output
(
"reshape2"
,
"XShape"
);
auto
*
qkv_reshape_out_node
=
pattern
->
NewNode
(
qkv_reshape_out_repr
())
->
assert_is_op_output
(
"reshape2"
);
qkv_transpose_out_node
->
assert_is_op_input
(
"reshape2"
,
"X"
);
qkv_reshape_node
->
LinksFrom
({
qkv_transpose_out_node
})
.
LinksTo
({
qkv_reshape_x_shape_node
,
qkv_reshape_out_node
});
// out linear pattern
auto
*
out_linear_matmul_node
=
pattern
->
NewNode
(
out_linear_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
out_linear_matmul_w_node
=
pattern
->
NewNode
(
out_linear_matmul_w_repr
())
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
out_linear_matmul_out_node
=
pattern
->
NewNode
(
out_linear_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
qkv_reshape_out_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
out_linear_matmul_node
->
LinksFrom
({
qkv_reshape_out_node
,
out_linear_matmul_w_node
})
.
LinksTo
({
out_linear_matmul_out_node
});
auto
*
out_linear_ele_add_node
=
pattern
->
NewNode
(
out_linear_ele_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
out_linear_ele_add_bias_node
=
pattern
->
NewNode
(
out_linear_ele_add_bias_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
out_linear_ele_add_out_node
=
pattern
->
NewNode
(
out_linear_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
out_linear_matmul_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
out_linear_ele_add_node
->
LinksFrom
({
out_linear_matmul_out_node
,
out_linear_ele_add_bias_node
})
.
LinksTo
({
out_linear_ele_add_out_node
});
auto
*
out_linear_dropout_node
=
pattern
->
NewNode
(
out_linear_dropout_op_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
out_linear_dropout_mask_node
=
pattern
->
NewNode
(
out_linear_dropout_mask_repr
())
->
assert_is_op_output
(
"dropout"
,
"Mask"
);
auto
*
out_linear_dropout_out_node
=
pattern
->
NewNode
(
out_linear_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
);
out_linear_ele_add_out_node
->
assert_is_op_input
(
"dropout"
,
"X"
);
out_linear_dropout_node
->
LinksFrom
({
out_linear_ele_add_out_node
})
.
LinksTo
({
out_linear_dropout_mask_node
,
out_linear_dropout_out_node
});
if
(
!
add_residual
&&
!
post_layer_norm
)
{
return
out_linear_dropout_out_node
;
}
// add residual
PDNode
*
residual_ele_add_out_node
{
nullptr
};
if
(
add_residual
)
{
// this kind of pattern only support `residual + dropout_out`, since we have
// to fix X and Y
auto
*
residual_ele_add_node
=
pattern
->
NewNode
(
residual_ele_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
residual_ele_add_out_node
=
pattern
->
NewNode
(
residual_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
out_linear_dropout_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
residual_ele_add_node
->
LinksFrom
({
x
,
out_linear_dropout_out_node
})
.
LinksTo
({
residual_ele_add_out_node
});
if
(
!
post_layer_norm
)
{
return
residual_ele_add_out_node
;
}
}
// post layer norm
auto
*
post_layer_norm_node
=
pattern
->
NewNode
(
post_layer_norm_op_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
post_layer_norm_scale_node
=
pattern
->
NewNode
(
post_layer_norm_scale_repr
())
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
post_layer_norm_bias_node
=
pattern
->
NewNode
(
post_layer_norm_bias_repr
())
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
post_layer_norm_out_node
=
pattern
->
NewNode
(
post_layer_norm_out_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Y"
);
auto
*
post_layer_norm_mean_node
=
pattern
->
NewNode
(
post_layer_norm_mean_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
post_layer_norm_variance_node
=
pattern
->
NewNode
(
post_layer_norm_variance_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
if
(
add_residual
)
{
residual_ele_add_out_node
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
post_layer_norm_node
->
LinksFrom
({
residual_ele_add_out_node
,
post_layer_norm_scale_node
,
post_layer_norm_bias_node
})
.
LinksTo
({
post_layer_norm_out_node
,
post_layer_norm_mean_node
,
post_layer_norm_variance_node
});
}
else
{
out_linear_dropout_out_node
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
post_layer_norm_node
->
LinksFrom
({
out_linear_dropout_out_node
,
post_layer_norm_scale_node
,
post_layer_norm_bias_node
})
.
LinksTo
({
post_layer_norm_out_node
,
post_layer_norm_mean_node
,
post_layer_norm_variance_node
});
}
return
post_layer_norm_out_node
;
}
PDNode
*
FusedAttentionGradPattern
::
operator
()(
PDNode
*
x
,
bool
pre_layer_norm
,
bool
post_layer_norm
,
bool
has_attn_mask
,
bool
do_dropout
,
bool
add_residual
)
{
// TODO(Yuang Liu): finish the backward pattern
return
nullptr
;
}
}
// namespace patterns
void
FusedAttentionsPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
graph
=
PreMaskDropResPostFwd
(
graph
);
graph
=
PreMaskDropResPostBwd
(
graph
);
}
ir
::
Graph
*
FusedAttentionsPass
::
PreMaskDropResPostFwd
(
Graph
*
graph
)
const
{
GraphPatternDetector
gpd
;
auto
*
x
=
gpd
.
mutable_pattern
()
->
NewNode
(
patterns
::
PDNodeName
(
name_scope_
,
"x"
))
->
AsInput
()
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
patterns
::
FusedAttentionPattern
fused_attention_pattern
(
gpd
.
mutable_pattern
(),
"fused_attention_pattern"
);
fused_attention_pattern
(
x
,
/* pre_layer_norm */
true
,
/* post_layer_norm */
true
,
/* has_attn_mask */
true
,
/* do_dropout */
true
,
/* add_residual */
true
);
int
found_fused_attention
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
3
)
<<
"handle FusedMultiHeadAttention pass's fusion"
;
GET_IR_NODE_FROM_SUBGRAPH
(
pre_layer_norm_op_node
,
pre_layer_norm_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_op_node
,
fuse_qkv_matmul_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_ele_add_op_node
,
fuse_qkv_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_reshape_op_node
,
fuse_qkv_reshape_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_transpose_op_node
,
fuse_qkv_transpose_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_split_op_node
,
fuse_qkv_split_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qk_matmul_op_node
,
qk_matmul_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qk_scale_op_node
,
qk_scale_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
add_mask_ele_add_op_node
,
add_mask_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qk_softmax_op_node
,
qk_softmax_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
attn_dropout_op_node
,
attn_dropout_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_matmul_op_node
,
qkv_matmul_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_transpose_op_node
,
qkv_transpose_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_reshape_op_node
,
qkv_reshape_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_matmul_op_node
,
out_linear_matmul_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_ele_add_op_node
,
out_linear_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_dropout_op_node
,
out_linear_dropout_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
residual_ele_add_op_node
,
residual_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
post_layer_norm_op_node
,
post_layer_norm_op
,
fused_attention_pattern
);
// TODO(Yuang Liu): finish the handler
GraphSafeRemoveNodes
(
g
,
{
pre_layer_norm_op_node
,
fuse_qkv_matmul_op_node
,
fuse_qkv_ele_add_op_node
,
fuse_qkv_reshape_op_node
,
fuse_qkv_transpose_op_node
,
fuse_qkv_split_op_node
,
qk_matmul_op_node
,
qk_scale_op_node
,
add_mask_ele_add_op_node
,
qk_softmax_op_node
,
attn_dropout_op_node
,
qkv_matmul_op_node
,
qkv_transpose_op_node
,
qkv_reshape_op_node
,
out_linear_matmul_op_node
,
out_linear_ele_add_op_node
,
out_linear_dropout_op_node
,
residual_ele_add_op_node
,
post_layer_norm_op_node
});
found_fused_attention
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_fused_attention
);
return
graph
;
}
ir
::
Graph
*
FusedAttentionsPass
::
PreMaskDropResPostBwd
(
Graph
*
graph
)
const
{
// TODO(Yuang Liu): finish the pass
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fused_attention_pass
,
paddle
::
framework
::
ir
::
FusedAttentionsPass
);
paddle/fluid/framework/ir/fused_attention_pass.h
0 → 100644
浏览文件 @
b0ece266
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
// Declare patterns for multi head attention.
// Can detect:
// 1. Pre layer norm, post layer norm or sandwich layer norm.
// 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not.
struct
FusedAttentionPattern
:
public
PatternBase
{
FusedAttentionPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_attention_pattern"
)
{}
PDNode
*
operator
()(
PDNode
*
x
,
bool
pre_layer_norm
,
// do pre ln or not
bool
post_layer_norm
,
// do post ln or not
bool
has_attn_mask
,
// add attn mask to qk or not
bool
do_dropout
,
// dropout the softmax(qk) or not
bool
add_residual
);
// add residual to out linear or not
// pre layer norm
PATTERN_DECL_NODE
(
pre_layer_norm_op
);
PATTERN_DECL_NODE
(
pre_layer_norm_scale
);
PATTERN_DECL_NODE
(
pre_layer_norm_bias
);
PATTERN_DECL_NODE
(
pre_layer_norm_out
);
PATTERN_DECL_NODE
(
pre_layer_norm_mean
);
PATTERN_DECL_NODE
(
pre_layer_norm_variance
);
// fuse qkv projection
PATTERN_DECL_NODE
(
fuse_qkv_matmul_op
);
PATTERN_DECL_NODE
(
fuse_qkv_matmul_w
);
PATTERN_DECL_NODE
(
fuse_qkv_matmul_out
);
PATTERN_DECL_NODE
(
fuse_qkv_ele_add_op
);
PATTERN_DECL_NODE
(
fuse_qkv_ele_add_bias
);
PATTERN_DECL_NODE
(
fuse_qkv_ele_add_out
);
PATTERN_DECL_NODE
(
fuse_qkv_reshape_op
);
PATTERN_DECL_NODE
(
fuse_qkv_reshape_out
);
PATTERN_DECL_NODE
(
fuse_qkv_reshape_x_shape
);
PATTERN_DECL_NODE
(
fuse_qkv_transpose_op
);
PATTERN_DECL_NODE
(
fuse_qkv_transpose_out
);
PATTERN_DECL_NODE
(
fuse_qkv_transpose_x_shape
);
PATTERN_DECL_NODE
(
fuse_qkv_split_op
);
PATTERN_DECL_NODE
(
fuse_qkv_split_out_q
);
// q
PATTERN_DECL_NODE
(
fuse_qkv_split_out_k
);
// k
PATTERN_DECL_NODE
(
fuse_qkv_split_out_v
);
// v
// core attention
PATTERN_DECL_NODE
(
qk_matmul_op
);
PATTERN_DECL_NODE
(
qk_matmul_out
);
PATTERN_DECL_NODE
(
qk_scale_op
);
PATTERN_DECL_NODE
(
qk_scale_out
);
PATTERN_DECL_NODE
(
add_mask_ele_add_op
);
PATTERN_DECL_NODE
(
add_mask_ele_add_mask
);
PATTERN_DECL_NODE
(
add_mask_ele_add_out
);
PATTERN_DECL_NODE
(
qk_softmax_op
);
PATTERN_DECL_NODE
(
qk_softmax_out
);
PATTERN_DECL_NODE
(
attn_dropout_op
);
PATTERN_DECL_NODE
(
attn_dropout_out
);
PATTERN_DECL_NODE
(
attn_dropout_mask
);
PATTERN_DECL_NODE
(
qkv_matmul_op
);
PATTERN_DECL_NODE
(
qkv_matmul_out
);
PATTERN_DECL_NODE
(
qkv_transpose_op
);
PATTERN_DECL_NODE
(
qkv_transpose_out
);
PATTERN_DECL_NODE
(
qkv_transpose_x_shape
);
PATTERN_DECL_NODE
(
qkv_reshape_op
);
PATTERN_DECL_NODE
(
qkv_reshape_out
);
PATTERN_DECL_NODE
(
qkv_reshape_x_shape
);
// out linear
PATTERN_DECL_NODE
(
out_linear_matmul_op
);
PATTERN_DECL_NODE
(
out_linear_matmul_w
);
PATTERN_DECL_NODE
(
out_linear_matmul_out
);
PATTERN_DECL_NODE
(
out_linear_ele_add_op
);
PATTERN_DECL_NODE
(
out_linear_ele_add_bias
);
PATTERN_DECL_NODE
(
out_linear_ele_add_out
);
PATTERN_DECL_NODE
(
out_linear_dropout_op
);
PATTERN_DECL_NODE
(
out_linear_dropout_out
);
PATTERN_DECL_NODE
(
out_linear_dropout_mask
);
// residual
PATTERN_DECL_NODE
(
residual_ele_add_op
);
PATTERN_DECL_NODE
(
residual_ele_add_out
);
// post layer norm
PATTERN_DECL_NODE
(
post_layer_norm_op
);
PATTERN_DECL_NODE
(
post_layer_norm_scale
);
PATTERN_DECL_NODE
(
post_layer_norm_bias
);
PATTERN_DECL_NODE
(
post_layer_norm_out
);
PATTERN_DECL_NODE
(
post_layer_norm_mean
);
PATTERN_DECL_NODE
(
post_layer_norm_variance
);
};
// Declare the grad pattern for multi head attention
struct
FusedAttentionGradPattern
:
public
PatternBase
{
FusedAttentionGradPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_attention_pattern"
)
{}
PDNode
*
operator
()(
PDNode
*
x
,
bool
pre_layer_norm
,
// pre ln
bool
post_layer_norm
,
// post ln
bool
has_attn_mask
,
// add attn mask to qk or not
bool
do_dropout
,
// dropout the softmax(qk) or not
bool
add_residual
);
// add residual to out linear or not
// TODO(Yuang Liu): add backward pattern
};
}
// namespace patterns
class
FusedAttentionsPass
:
public
FusePassBase
{
public:
virtual
~
FusedAttentionsPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"fused_attention_pass"
};
private:
// The name rule for the helper function.
// The function name will contain at most five parts in order:
// 1. Do pre layer norm? [Pre]
// 2. Add mask in the core attention part? [Mask]
// 3. Do dropout in the core attention part? [Drop]
// 4. Add residual? [Res]
// 5. Do post layer norm? [Post]
// 6. Forward or Backward? [Fwd/Bwd]
// If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it.
ir
::
Graph
*
PreMaskDropResPostFwd
(
Graph
*
graph
)
const
;
ir
::
Graph
*
PreMaskDropResPostBwd
(
Graph
*
graph
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/pybind/parallel_executor.cc
浏览文件 @
b0ece266
...
@@ -714,6 +714,32 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT
...
@@ -714,6 +714,32 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT
build_strategy = static.BuildStrategy()
build_strategy = static.BuildStrategy()
build_strategy.fuse_gemm_epilogue = True
build_strategy.fuse_gemm_epilogue = True
)DOC"
)
)DOC"
)
.
def_property
(
"fused_attention"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fused_attention_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
PADDLE_ENFORCE_NE
(
self
.
IsFinalized
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"BuildStrategy has been finlaized, cannot be "
"configured again."
));
self
.
fused_attention_
=
b
;
},
R"DOC((bool, optional): fused_attention indicate whether
to fuse the whole multi head attention part with one op,
it may make the execution faster. Default is False.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fused_attention = True
)DOC"
)
.
def_property
(
.
def_property
(
"fuse_bn_act_ops"
,
"fuse_bn_act_ops"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_bn_act_ops_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_bn_act_ops_
;
},
...
...
python/paddle/distributed/passes/cpp_pass.py
浏览文件 @
b0ece266
...
@@ -71,6 +71,19 @@ class FuseReluDepthwiseConvPass(CPPPassWrapper):
...
@@ -71,6 +71,19 @@ class FuseReluDepthwiseConvPass(CPPPassWrapper):
return
PassType
.
FUSION_OPT
return
PassType
.
FUSION_OPT
@
register_pass
(
"fused_attention"
)
class
FusedAttentionPass
(
CPPPassWrapper
):
def
__init__
(
self
):
super
().
__init__
()
@
property
def
cpp_name
(
self
):
return
"fused_attention_pass"
def
_type
(
self
):
return
PassType
.
FUSION_OPT
@
register_pass
(
"fuse_gemm_epilogue"
)
@
register_pass
(
"fuse_gemm_epilogue"
)
class
FuseGemmEpiloguePass
(
CPPPassWrapper
):
class
FuseGemmEpiloguePass
(
CPPPassWrapper
):
def
__init__
(
self
):
def
__init__
(
self
):
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
b0ece266
...
@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
...
@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
list
(
REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer
)
list
(
REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api
)
list
(
REMOVE_ITEM TEST_OPS test_fused_attention_pass
)
endif
()
endif
()
list
(
REMOVE_ITEM TEST_OPS test_fused_ec_moe_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_ec_moe_op
)
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
0 → 100644
浏览文件 @
b0ece266
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid.core
as
core
import
paddle.nn.functional
as
F
from
paddle.distributed.passes
import
PassManager
,
new_pass
paddle
.
enable_static
()
class
MultiHeadAttention
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
embed_dim
,
num_heads
,
add_residual
=
True
,
pre_ln
=
True
,
post_ln
=
False
,
attn_dropout
=
True
,
):
super
(
MultiHeadAttention
,
self
).
__init__
()
self
.
embed_dim
=
embed_dim
self
.
kdim
=
embed_dim
self
.
vdim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
add_residual
=
add_residual
self
.
pre_ln
=
pre_ln
self
.
post_ln
=
post_ln
self
.
attn_dropout
=
attn_dropout
self
.
head_dim
=
embed_dim
//
num_heads
assert
(
self
.
head_dim
*
num_heads
==
self
.
embed_dim
),
"embed_dim must be divisible by num_heads"
self
.
norm1
=
paddle
.
nn
.
LayerNorm
(
embed_dim
,
epsilon
=
1e-5
)
self
.
norm2
=
paddle
.
nn
.
LayerNorm
(
embed_dim
,
epsilon
=
1e-5
)
self
.
qkv_proj
=
paddle
.
nn
.
Linear
(
embed_dim
,
3
*
embed_dim
)
self
.
out_proj
=
paddle
.
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
dropout
=
paddle
.
nn
.
Dropout
(
0.1
,
mode
=
"upscale_in_train"
)
def
forward
(
self
,
x
,
attn_mask
=
None
):
residual
=
x
if
self
.
pre_ln
:
# pre layer norm
x
=
self
.
norm1
(
x
)
# compute qkv
qkv
=
self
.
qkv_proj
(
x
)
qkv
=
paddle
.
reshape
(
qkv
,
[
0
,
0
,
self
.
num_heads
,
3
*
self
.
head_dim
])
qkv
=
paddle
.
transpose
(
qkv
,
[
0
,
2
,
1
,
3
])
q
,
k
,
v
=
paddle
.
split
(
qkv
,
num_or_sections
=
3
,
axis
=-
1
)
# compute core attention
product
=
paddle
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
)
product
=
paddle
.
scale
(
product
,
scale
=
self
.
head_dim
**-
0.5
)
if
attn_mask
is
not
None
:
product
=
product
+
attn_mask
weights
=
F
.
softmax
(
product
)
if
self
.
attn_dropout
:
weights
=
F
.
dropout
(
weights
,
0.1
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
)
out
=
paddle
.
matmul
(
weights
,
v
)
out
=
paddle
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
paddle
.
reshape
(
x
=
out
,
shape
=
[
0
,
0
,
out
.
shape
[
2
]
*
out
.
shape
[
3
]])
# project to output
out
=
self
.
out_proj
(
out
)
out
=
self
.
dropout
(
out
)
if
self
.
add_residual
:
out
=
residual
+
out
if
self
.
post_ln
:
# post layer norm
out
=
self
.
norm2
(
out
)
return
out
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestFusedAttentionPass
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
add_residual
=
True
self
.
pre_ln
=
True
self
.
post_ln
=
True
self
.
attn_dropout
=
True
self
.
add_mask
=
True
def
test_pass
(
self
):
batch_size
=
2
seq_len
=
1024
hidden_size
=
768
num_heads
=
12
x_data
=
np
.
random
.
rand
(
batch_size
,
seq_len
,
hidden_size
).
astype
(
'float32'
)
mask_data
=
np
.
random
.
rand
(
batch_size
,
num_heads
,
seq_len
,
seq_len
).
astype
(
'float32'
)
main_prog
=
paddle
.
static
.
Program
()
startup_prog
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_prog
,
startup_prog
):
data
=
paddle
.
static
.
data
(
name
=
"x"
,
shape
=
[
-
1
,
seq_len
,
hidden_size
],
dtype
=
'float32'
,
)
if
self
.
add_mask
:
attn_mask
=
paddle
.
static
.
data
(
name
=
"attn_mask"
,
shape
=
[
-
1
,
num_heads
,
seq_len
,
seq_len
],
dtype
=
'float32'
,
)
else
:
attn_mask
=
None
multi_head_attn
=
MultiHeadAttention
(
hidden_size
,
num_heads
,
add_residual
=
self
.
add_residual
,
pre_ln
=
self
.
pre_ln
,
post_ln
=
self
.
post_ln
,
attn_dropout
=
self
.
attn_dropout
,
)
out
=
multi_head_attn
(
data
,
attn_mask
)
loss
=
paddle
.
mean
(
out
)
sgd_optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
sgd_optimizer
.
minimize
(
loss
)
pass_manager
=
PassManager
([
new_pass
(
"fused_attention"
)])
pass_manager
.
apply
([
main_prog
],
[
startup_prog
])
ops
=
main_prog
.
global_block
().
ops
assert
ops
[
0
].
type
==
'reduce_mean'
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
0
)
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录