Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b0ece266
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b0ece266
编写于
1月 10, 2023
作者:
Y
Yuang Liu
提交者:
GitHub
1月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Fuse attention pass] Forward pattern. (#49621)
上级
13992de7
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
844 addition
and
0 deletion
+844
-0
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-0
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+7
-0
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+3
-0
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+4
-0
paddle/fluid/framework/ir/fused_attention_pass.cc
paddle/fluid/framework/ir/fused_attention_pass.cc
+448
-0
paddle/fluid/framework/ir/fused_attention_pass.h
paddle/fluid/framework/ir/fused_attention_pass.h
+176
-0
paddle/fluid/pybind/parallel_executor.cc
paddle/fluid/pybind/parallel_executor.cc
+26
-0
python/paddle/distributed/passes/cpp_pass.py
python/paddle/distributed/passes/cpp_pass.py
+13
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
...paddle/fluid/tests/unittests/test_fused_attention_pass.py
+164
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
b0ece266
...
@@ -382,6 +382,7 @@ set(IR_PASS_DEPS
...
@@ -382,6 +382,7 @@ set(IR_PASS_DEPS
graph_to_program_pass
graph_to_program_pass
fix_op_run_order_pass
fix_op_run_order_pass
fuse_gemm_epilogue_pass
fuse_gemm_epilogue_pass
fused_attention_pass
delete_dropout_op_pass
)
delete_dropout_op_pass
)
if
(
WITH_CINN
)
if
(
WITH_CINN
)
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
b0ece266
...
@@ -187,6 +187,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -187,6 +187,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPassWithCheck
(
strategy_
.
enable_auto_fusion_
,
"fusion_group_pass"
);
AppendPassWithCheck
(
strategy_
.
enable_auto_fusion_
,
"fusion_group_pass"
);
#endif
#endif
#ifdef PADDLE_WITH_CUDA
AppendPassWithCheck
(
strategy_
.
fused_attention_
,
"fused_attention_pass"
);
#endif
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060)
AppendPassWithCheck
(
strategy_
.
fuse_gemm_epilogue_
,
AppendPassWithCheck
(
strategy_
.
fuse_gemm_epilogue_
,
"fuse_gemm_epilogue_pass"
);
"fuse_gemm_epilogue_pass"
);
...
@@ -519,6 +523,9 @@ USE_PASS(fuse_all_reduce_op_pass);
...
@@ -519,6 +523,9 @@ USE_PASS(fuse_all_reduce_op_pass);
USE_PASS
(
runtime_context_cache_pass
);
USE_PASS
(
runtime_context_cache_pass
);
USE_PASS
(
add_reader_dependency_pass
);
USE_PASS
(
add_reader_dependency_pass
);
USE_PASS
(
delete_dropout_op_x_pass
);
USE_PASS
(
delete_dropout_op_x_pass
);
#ifdef PADDLE_WITH_CUDA
USE_PASS
(
fused_attention_pass
);
#endif
#ifdef PADDLE_WITH_CINN
#ifdef PADDLE_WITH_CINN
USE_PASS
(
build_cinn_pass
);
USE_PASS
(
build_cinn_pass
);
#endif
#endif
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
b0ece266
...
@@ -129,6 +129,8 @@ struct BuildStrategy {
...
@@ -129,6 +129,8 @@ struct BuildStrategy {
bool
sync_batch_norm_
{
false
};
bool
sync_batch_norm_
{
false
};
// Fuse GEMM+Epilogue via cublasLt epilogue.
// Fuse GEMM+Epilogue via cublasLt epilogue.
bool
fuse_gemm_epilogue_
{
false
};
bool
fuse_gemm_epilogue_
{
false
};
// Fused multi head attention
bool
fused_attention_
{
false
};
// mkldnn_enabled_op_types specify the operator type list to
// mkldnn_enabled_op_types specify the operator type list to
// use MKLDNN acceleration. It is null in default, means
// use MKLDNN acceleration. It is null in default, means
...
@@ -261,6 +263,7 @@ inline std::ostream &operator<<(std::ostream &os,
...
@@ -261,6 +263,7 @@ inline std::ostream &operator<<(std::ostream &os,
os
<<
"fuse_broadcast_ops_: "
<<
strategy
.
fuse_broadcast_ops_
<<
std
::
endl
;
os
<<
"fuse_broadcast_ops_: "
<<
strategy
.
fuse_broadcast_ops_
<<
std
::
endl
;
os
<<
"sync_batch_norm_: "
<<
strategy
.
sync_batch_norm_
<<
std
::
endl
;
os
<<
"sync_batch_norm_: "
<<
strategy
.
sync_batch_norm_
<<
std
::
endl
;
os
<<
"fuse_gemm_epilogue_: "
<<
strategy
.
fuse_gemm_epilogue_
<<
std
::
endl
;
os
<<
"fuse_gemm_epilogue_: "
<<
strategy
.
fuse_gemm_epilogue_
<<
std
::
endl
;
os
<<
"fused_attention_: "
<<
strategy
.
fused_attention_
<<
std
::
endl
;
os
<<
"mkldnn_enabled_op_types_: "
;
os
<<
"mkldnn_enabled_op_types_: "
;
for
(
auto
str
:
strategy
.
mkldnn_enabled_op_types_
)
{
for
(
auto
str
:
strategy
.
mkldnn_enabled_op_types_
)
{
os
<<
str
<<
", "
;
os
<<
str
<<
", "
;
...
...
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
b0ece266
...
@@ -124,6 +124,7 @@ message BuildStrategy {
...
@@ -124,6 +124,7 @@ message BuildStrategy {
optional
int32
reduce_strategy
=
15
[
default
=
0
];
optional
int32
reduce_strategy
=
15
[
default
=
0
];
optional
bool
fuse_gemm_epilogue
=
16
[
default
=
false
];
optional
bool
fuse_gemm_epilogue
=
16
[
default
=
false
];
optional
string
debug_graphviz_path
=
17
;
optional
string
debug_graphviz_path
=
17
;
optional
bool
fused_attention
=
18
[
default
=
false
];
}
}
message
ExecutionStrategy
{
message
ExecutionStrategy
{
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
b0ece266
...
@@ -223,6 +223,10 @@ cc_library(
...
@@ -223,6 +223,10 @@ cc_library(
fuse_gemm_epilogue_pass
fuse_gemm_epilogue_pass
SRCS fuse_gemm_epilogue_pass.cc
SRCS fuse_gemm_epilogue_pass.cc
DEPS pass graph_pattern_detector
)
DEPS pass graph_pattern_detector
)
cc_library
(
fused_attention_pass
SRCS fused_attention_pass.cc
DEPS pass graph_pattern_detector
)
cc_library
(
cc_library
(
fuse_relu_depthwise_conv_pass
fuse_relu_depthwise_conv_pass
SRCS fuse_relu_depthwise_conv_pass.cc
SRCS fuse_relu_depthwise_conv_pass.cc
...
...
paddle/fluid/framework/ir/fused_attention_pass.cc
0 → 100644
浏览文件 @
b0ece266
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_attention_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
PDNode
*
FusedAttentionPattern
::
operator
()(
PDNode
*
x
,
bool
pre_layer_norm
,
bool
post_layer_norm
,
bool
has_attn_mask
,
bool
do_dropout
,
bool
add_residual
)
{
// pre layer norm pattern
PDNode
*
pre_layer_norm_out_node
{
nullptr
};
if
(
pre_layer_norm
)
{
auto
*
pre_layer_norm_node
=
pattern
->
NewNode
(
pre_layer_norm_op_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
pre_layer_norm_scale_node
=
pattern
->
NewNode
(
pre_layer_norm_scale_repr
())
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
pre_layer_norm_bias_node
=
pattern
->
NewNode
(
pre_layer_norm_bias_repr
())
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
pre_layer_norm_out_node
=
pattern
->
NewNode
(
pre_layer_norm_out_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Y"
);
auto
*
pre_layer_norm_mean_node
=
pattern
->
NewNode
(
pre_layer_norm_mean_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
pre_layer_norm_variance_node
=
pattern
->
NewNode
(
pre_layer_norm_variance_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
pre_layer_norm_node
->
LinksFrom
({
x
,
pre_layer_norm_scale_node
,
pre_layer_norm_bias_node
})
.
LinksTo
({
pre_layer_norm_out_node
,
pre_layer_norm_mean_node
,
pre_layer_norm_variance_node
});
}
// fuse qkv pattern
auto
*
fuse_qkv_matmul_node
=
pattern
->
NewNode
(
fuse_qkv_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
fuse_qkv_matmul_w_node
=
pattern
->
NewNode
(
fuse_qkv_matmul_w_repr
())
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
fuse_qkv_matmul_out_node
=
pattern
->
NewNode
(
fuse_qkv_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
if
(
pre_layer_norm
)
{
pre_layer_norm_out_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
fuse_qkv_matmul_node
->
LinksFrom
({
pre_layer_norm_out_node
,
fuse_qkv_matmul_w_node
})
.
LinksTo
({
fuse_qkv_matmul_out_node
});
}
else
{
fuse_qkv_matmul_node
->
LinksFrom
({
x
,
fuse_qkv_matmul_w_node
})
.
LinksTo
({
fuse_qkv_matmul_out_node
});
}
auto
*
fuse_qkv_ele_add_node
=
pattern
->
NewNode
(
fuse_qkv_ele_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
fuse_qkv_ele_add_bias_node
=
pattern
->
NewNode
(
fuse_qkv_ele_add_bias_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
fuse_qkv_ele_add_out_node
=
pattern
->
NewNode
(
fuse_qkv_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
fuse_qkv_matmul_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
fuse_qkv_ele_add_node
->
LinksFrom
({
fuse_qkv_matmul_out_node
,
fuse_qkv_ele_add_bias_node
})
.
LinksTo
({
fuse_qkv_ele_add_out_node
});
auto
*
fuse_qkv_reshape_node
=
pattern
->
NewNode
(
fuse_qkv_reshape_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
fuse_qkv_reshape_x_shape_node
=
pattern
->
NewNode
(
fuse_qkv_reshape_x_shape_repr
())
->
assert_is_op_output
(
"reshape2"
,
"XShape"
);
auto
*
fuse_qkv_reshape_out_node
=
pattern
->
NewNode
(
fuse_qkv_reshape_out_repr
())
->
assert_is_op_output
(
"reshape2"
);
fuse_qkv_ele_add_out_node
->
assert_is_op_input
(
"reshape2"
,
"X"
);
fuse_qkv_reshape_node
->
LinksFrom
({
fuse_qkv_ele_add_out_node
})
.
LinksTo
({
fuse_qkv_reshape_x_shape_node
,
fuse_qkv_reshape_out_node
});
auto
*
fuse_qkv_transpose_node
=
pattern
->
NewNode
(
fuse_qkv_transpose_op_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
fuse_qkv_transpose_x_shape_node
=
pattern
->
NewNode
(
fuse_qkv_transpose_x_shape_repr
())
->
assert_is_op_output
(
"transpose2"
,
"XShape"
);
auto
*
fuse_qkv_transpose_out_node
=
pattern
->
NewNode
(
fuse_qkv_transpose_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
fuse_qkv_reshape_out_node
->
assert_is_op_input
(
"transpose2"
,
"X"
);
fuse_qkv_transpose_node
->
LinksFrom
({
fuse_qkv_reshape_out_node
})
.
LinksTo
({
fuse_qkv_transpose_x_shape_node
,
fuse_qkv_transpose_out_node
});
auto
*
fuse_qkv_split_node
=
pattern
->
NewNode
(
fuse_qkv_split_op_repr
())
->
assert_is_op
(
"split"
);
auto
*
fuse_qkv_split_out_q_node
=
pattern
->
NewNode
(
fuse_qkv_split_out_q_repr
())
->
assert_is_op_output
(
"split"
);
auto
*
fuse_qkv_split_out_k_node
=
pattern
->
NewNode
(
fuse_qkv_split_out_k_repr
())
->
assert_is_op_output
(
"split"
);
auto
*
fuse_qkv_split_out_v_node
=
pattern
->
NewNode
(
fuse_qkv_split_out_v_repr
())
->
assert_is_op_output
(
"split"
);
fuse_qkv_transpose_out_node
->
assert_is_op_input
(
"split"
,
"X"
);
fuse_qkv_split_node
->
LinksFrom
({
fuse_qkv_transpose_out_node
})
.
LinksTo
({
fuse_qkv_split_out_q_node
,
fuse_qkv_split_out_k_node
,
fuse_qkv_split_out_v_node
});
// core attention pattern
auto
*
qk_matmul_node
=
pattern
->
NewNode
(
qk_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
qk_matmul_out_node
=
pattern
->
NewNode
(
qk_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
fuse_qkv_split_out_q_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
fuse_qkv_split_out_k_node
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
qk_matmul_node
->
LinksFrom
({
fuse_qkv_split_out_q_node
,
fuse_qkv_split_out_k_node
})
.
LinksTo
({
qk_matmul_out_node
});
auto
*
qk_scale_node
=
pattern
->
NewNode
(
qk_scale_op_repr
())
->
assert_is_op
(
"scale"
);
auto
*
qk_scale_out_node
=
pattern
->
NewNode
(
qk_scale_out_repr
())
->
assert_is_op_output
(
"scale"
);
qk_matmul_out_node
->
assert_is_op_input
(
"scale"
,
"X"
);
qk_scale_node
->
LinksFrom
({
qk_matmul_out_node
}).
LinksTo
({
qk_scale_out_node
});
PDNode
*
add_mask_ele_add_out_node
{
nullptr
};
if
(
has_attn_mask
)
{
auto
*
add_mask_ele_add_node
=
pattern
->
NewNode
(
add_mask_ele_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
add_mask_ele_add_mask_node
=
pattern
->
NewNode
(
add_mask_ele_add_mask_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
add_mask_ele_add_out_node
=
pattern
->
NewNode
(
add_mask_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
qk_scale_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
add_mask_ele_add_node
->
LinksFrom
({
qk_scale_out_node
,
add_mask_ele_add_mask_node
})
.
LinksTo
({
add_mask_ele_add_out_node
});
}
auto
*
qk_softmax_node
=
pattern
->
NewNode
(
qk_softmax_op_repr
())
->
assert_is_op
(
"softmax"
);
auto
*
qk_softmax_out_node
=
pattern
->
NewNode
(
qk_softmax_out_repr
())
->
assert_is_op_output
(
"softmax"
);
if
(
has_attn_mask
)
{
add_mask_ele_add_out_node
->
assert_is_op_input
(
"softmax"
,
"X"
);
qk_softmax_node
->
LinksFrom
({
add_mask_ele_add_out_node
})
.
LinksTo
({
qk_softmax_out_node
});
}
else
{
qk_scale_out_node
->
assert_is_op_input
(
"softmax"
,
"X"
);
qk_softmax_node
->
LinksFrom
({
qk_scale_out_node
})
.
LinksTo
({
qk_softmax_out_node
});
}
PDNode
*
attn_dropout_out_node
{
nullptr
};
if
(
do_dropout
)
{
auto
*
attn_dropout_node
=
pattern
->
NewNode
(
attn_dropout_op_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
attn_dropout_mask_node
=
pattern
->
NewNode
(
attn_dropout_mask_repr
())
->
assert_is_op_output
(
"dropout"
,
"Mask"
);
attn_dropout_out_node
=
pattern
->
NewNode
(
attn_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
);
qk_softmax_out_node
->
assert_is_op_input
(
"dropout"
,
"X"
);
attn_dropout_node
->
LinksFrom
({
qk_softmax_out_node
})
.
LinksTo
({
attn_dropout_mask_node
,
attn_dropout_out_node
});
}
auto
*
qkv_matmul_node
=
pattern
->
NewNode
(
qkv_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
qkv_matmul_out_node
=
pattern
->
NewNode
(
qkv_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
fuse_qkv_split_out_v_node
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
if
(
do_dropout
)
{
attn_dropout_out_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
qkv_matmul_node
->
LinksFrom
({
attn_dropout_out_node
,
fuse_qkv_split_out_v_node
})
.
LinksTo
({
qkv_matmul_out_node
});
}
else
{
qk_softmax_out_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
qkv_matmul_node
->
LinksFrom
({
qk_softmax_out_node
,
fuse_qkv_split_out_v_node
})
.
LinksTo
({
qkv_matmul_out_node
});
}
auto
*
qkv_transpose_node
=
pattern
->
NewNode
(
qkv_transpose_op_repr
())
->
assert_is_op
(
"transpose2"
);
auto
*
qkv_transpose_x_shape_node
=
pattern
->
NewNode
(
qkv_transpose_x_shape_repr
())
->
assert_is_op_output
(
"transpose2"
,
"XShape"
);
auto
*
qkv_transpose_out_node
=
pattern
->
NewNode
(
qkv_transpose_out_repr
())
->
assert_is_op_output
(
"transpose2"
);
qkv_matmul_out_node
->
assert_is_op_input
(
"transpose2"
,
"X"
);
qkv_transpose_node
->
LinksFrom
({
qkv_matmul_out_node
})
.
LinksTo
({
qkv_transpose_x_shape_node
,
qkv_transpose_out_node
});
auto
*
qkv_reshape_node
=
pattern
->
NewNode
(
qkv_reshape_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
*
qkv_reshape_x_shape_node
=
pattern
->
NewNode
(
qkv_reshape_x_shape_repr
())
->
assert_is_op_output
(
"reshape2"
,
"XShape"
);
auto
*
qkv_reshape_out_node
=
pattern
->
NewNode
(
qkv_reshape_out_repr
())
->
assert_is_op_output
(
"reshape2"
);
qkv_transpose_out_node
->
assert_is_op_input
(
"reshape2"
,
"X"
);
qkv_reshape_node
->
LinksFrom
({
qkv_transpose_out_node
})
.
LinksTo
({
qkv_reshape_x_shape_node
,
qkv_reshape_out_node
});
// out linear pattern
auto
*
out_linear_matmul_node
=
pattern
->
NewNode
(
out_linear_matmul_op_repr
())
->
assert_is_op
(
"matmul_v2"
);
auto
*
out_linear_matmul_w_node
=
pattern
->
NewNode
(
out_linear_matmul_w_repr
())
->
assert_is_op_input
(
"matmul_v2"
,
"Y"
);
auto
*
out_linear_matmul_out_node
=
pattern
->
NewNode
(
out_linear_matmul_out_repr
())
->
assert_is_op_output
(
"matmul_v2"
);
qkv_reshape_out_node
->
assert_is_op_input
(
"matmul_v2"
,
"X"
);
out_linear_matmul_node
->
LinksFrom
({
qkv_reshape_out_node
,
out_linear_matmul_w_node
})
.
LinksTo
({
out_linear_matmul_out_node
});
auto
*
out_linear_ele_add_node
=
pattern
->
NewNode
(
out_linear_ele_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
out_linear_ele_add_bias_node
=
pattern
->
NewNode
(
out_linear_ele_add_bias_repr
())
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
out_linear_ele_add_out_node
=
pattern
->
NewNode
(
out_linear_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
out_linear_matmul_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
out_linear_ele_add_node
->
LinksFrom
({
out_linear_matmul_out_node
,
out_linear_ele_add_bias_node
})
.
LinksTo
({
out_linear_ele_add_out_node
});
auto
*
out_linear_dropout_node
=
pattern
->
NewNode
(
out_linear_dropout_op_repr
())
->
assert_is_op
(
"dropout"
);
auto
*
out_linear_dropout_mask_node
=
pattern
->
NewNode
(
out_linear_dropout_mask_repr
())
->
assert_is_op_output
(
"dropout"
,
"Mask"
);
auto
*
out_linear_dropout_out_node
=
pattern
->
NewNode
(
out_linear_dropout_out_repr
())
->
assert_is_op_output
(
"dropout"
);
out_linear_ele_add_out_node
->
assert_is_op_input
(
"dropout"
,
"X"
);
out_linear_dropout_node
->
LinksFrom
({
out_linear_ele_add_out_node
})
.
LinksTo
({
out_linear_dropout_mask_node
,
out_linear_dropout_out_node
});
if
(
!
add_residual
&&
!
post_layer_norm
)
{
return
out_linear_dropout_out_node
;
}
// add residual
PDNode
*
residual_ele_add_out_node
{
nullptr
};
if
(
add_residual
)
{
// this kind of pattern only support `residual + dropout_out`, since we have
// to fix X and Y
auto
*
residual_ele_add_node
=
pattern
->
NewNode
(
residual_ele_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
residual_ele_add_out_node
=
pattern
->
NewNode
(
residual_ele_add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
);
out_linear_dropout_out_node
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
residual_ele_add_node
->
LinksFrom
({
x
,
out_linear_dropout_out_node
})
.
LinksTo
({
residual_ele_add_out_node
});
if
(
!
post_layer_norm
)
{
return
residual_ele_add_out_node
;
}
}
// post layer norm
auto
*
post_layer_norm_node
=
pattern
->
NewNode
(
post_layer_norm_op_repr
())
->
assert_is_op
(
"layer_norm"
);
auto
*
post_layer_norm_scale_node
=
pattern
->
NewNode
(
post_layer_norm_scale_repr
())
->
assert_is_op_input
(
"layer_norm"
,
"Scale"
);
auto
*
post_layer_norm_bias_node
=
pattern
->
NewNode
(
post_layer_norm_bias_repr
())
->
assert_is_op_input
(
"layer_norm"
,
"Bias"
);
auto
*
post_layer_norm_out_node
=
pattern
->
NewNode
(
post_layer_norm_out_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Y"
);
auto
*
post_layer_norm_mean_node
=
pattern
->
NewNode
(
post_layer_norm_mean_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Mean"
);
auto
*
post_layer_norm_variance_node
=
pattern
->
NewNode
(
post_layer_norm_variance_repr
())
->
assert_is_op_output
(
"layer_norm"
,
"Variance"
);
if
(
add_residual
)
{
residual_ele_add_out_node
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
post_layer_norm_node
->
LinksFrom
({
residual_ele_add_out_node
,
post_layer_norm_scale_node
,
post_layer_norm_bias_node
})
.
LinksTo
({
post_layer_norm_out_node
,
post_layer_norm_mean_node
,
post_layer_norm_variance_node
});
}
else
{
out_linear_dropout_out_node
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
post_layer_norm_node
->
LinksFrom
({
out_linear_dropout_out_node
,
post_layer_norm_scale_node
,
post_layer_norm_bias_node
})
.
LinksTo
({
post_layer_norm_out_node
,
post_layer_norm_mean_node
,
post_layer_norm_variance_node
});
}
return
post_layer_norm_out_node
;
}
PDNode
*
FusedAttentionGradPattern
::
operator
()(
PDNode
*
x
,
bool
pre_layer_norm
,
bool
post_layer_norm
,
bool
has_attn_mask
,
bool
do_dropout
,
bool
add_residual
)
{
// TODO(Yuang Liu): finish the backward pattern
return
nullptr
;
}
}
// namespace patterns
void
FusedAttentionsPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
);
graph
=
PreMaskDropResPostFwd
(
graph
);
graph
=
PreMaskDropResPostBwd
(
graph
);
}
ir
::
Graph
*
FusedAttentionsPass
::
PreMaskDropResPostFwd
(
Graph
*
graph
)
const
{
GraphPatternDetector
gpd
;
auto
*
x
=
gpd
.
mutable_pattern
()
->
NewNode
(
patterns
::
PDNodeName
(
name_scope_
,
"x"
))
->
AsInput
()
->
assert_is_op_input
(
"layer_norm"
,
"X"
);
patterns
::
FusedAttentionPattern
fused_attention_pattern
(
gpd
.
mutable_pattern
(),
"fused_attention_pattern"
);
fused_attention_pattern
(
x
,
/* pre_layer_norm */
true
,
/* post_layer_norm */
true
,
/* has_attn_mask */
true
,
/* do_dropout */
true
,
/* add_residual */
true
);
int
found_fused_attention
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
3
)
<<
"handle FusedMultiHeadAttention pass's fusion"
;
GET_IR_NODE_FROM_SUBGRAPH
(
pre_layer_norm_op_node
,
pre_layer_norm_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_matmul_op_node
,
fuse_qkv_matmul_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_ele_add_op_node
,
fuse_qkv_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_reshape_op_node
,
fuse_qkv_reshape_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_transpose_op_node
,
fuse_qkv_transpose_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fuse_qkv_split_op_node
,
fuse_qkv_split_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qk_matmul_op_node
,
qk_matmul_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qk_scale_op_node
,
qk_scale_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
add_mask_ele_add_op_node
,
add_mask_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qk_softmax_op_node
,
qk_softmax_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
attn_dropout_op_node
,
attn_dropout_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_matmul_op_node
,
qkv_matmul_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_transpose_op_node
,
qkv_transpose_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
qkv_reshape_op_node
,
qkv_reshape_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_matmul_op_node
,
out_linear_matmul_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_ele_add_op_node
,
out_linear_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
out_linear_dropout_op_node
,
out_linear_dropout_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
residual_ele_add_op_node
,
residual_ele_add_op
,
fused_attention_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
post_layer_norm_op_node
,
post_layer_norm_op
,
fused_attention_pattern
);
// TODO(Yuang Liu): finish the handler
GraphSafeRemoveNodes
(
g
,
{
pre_layer_norm_op_node
,
fuse_qkv_matmul_op_node
,
fuse_qkv_ele_add_op_node
,
fuse_qkv_reshape_op_node
,
fuse_qkv_transpose_op_node
,
fuse_qkv_split_op_node
,
qk_matmul_op_node
,
qk_scale_op_node
,
add_mask_ele_add_op_node
,
qk_softmax_op_node
,
attn_dropout_op_node
,
qkv_matmul_op_node
,
qkv_transpose_op_node
,
qkv_reshape_op_node
,
out_linear_matmul_op_node
,
out_linear_ele_add_op_node
,
out_linear_dropout_op_node
,
residual_ele_add_op_node
,
post_layer_norm_op_node
});
found_fused_attention
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_fused_attention
);
return
graph
;
}
ir
::
Graph
*
FusedAttentionsPass
::
PreMaskDropResPostBwd
(
Graph
*
graph
)
const
{
// TODO(Yuang Liu): finish the pass
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fused_attention_pass
,
paddle
::
framework
::
ir
::
FusedAttentionsPass
);
paddle/fluid/framework/ir/fused_attention_pass.h
0 → 100644
浏览文件 @
b0ece266
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
// Declare patterns for multi head attention.
// Can detect:
// 1. Pre layer norm, post layer norm or sandwich layer norm.
// 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not.
struct
FusedAttentionPattern
:
public
PatternBase
{
FusedAttentionPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_attention_pattern"
)
{}
PDNode
*
operator
()(
PDNode
*
x
,
bool
pre_layer_norm
,
// do pre ln or not
bool
post_layer_norm
,
// do post ln or not
bool
has_attn_mask
,
// add attn mask to qk or not
bool
do_dropout
,
// dropout the softmax(qk) or not
bool
add_residual
);
// add residual to out linear or not
// pre layer norm
PATTERN_DECL_NODE
(
pre_layer_norm_op
);
PATTERN_DECL_NODE
(
pre_layer_norm_scale
);
PATTERN_DECL_NODE
(
pre_layer_norm_bias
);
PATTERN_DECL_NODE
(
pre_layer_norm_out
);
PATTERN_DECL_NODE
(
pre_layer_norm_mean
);
PATTERN_DECL_NODE
(
pre_layer_norm_variance
);
// fuse qkv projection
PATTERN_DECL_NODE
(
fuse_qkv_matmul_op
);
PATTERN_DECL_NODE
(
fuse_qkv_matmul_w
);
PATTERN_DECL_NODE
(
fuse_qkv_matmul_out
);
PATTERN_DECL_NODE
(
fuse_qkv_ele_add_op
);
PATTERN_DECL_NODE
(
fuse_qkv_ele_add_bias
);
PATTERN_DECL_NODE
(
fuse_qkv_ele_add_out
);
PATTERN_DECL_NODE
(
fuse_qkv_reshape_op
);
PATTERN_DECL_NODE
(
fuse_qkv_reshape_out
);
PATTERN_DECL_NODE
(
fuse_qkv_reshape_x_shape
);
PATTERN_DECL_NODE
(
fuse_qkv_transpose_op
);
PATTERN_DECL_NODE
(
fuse_qkv_transpose_out
);
PATTERN_DECL_NODE
(
fuse_qkv_transpose_x_shape
);
PATTERN_DECL_NODE
(
fuse_qkv_split_op
);
PATTERN_DECL_NODE
(
fuse_qkv_split_out_q
);
// q
PATTERN_DECL_NODE
(
fuse_qkv_split_out_k
);
// k
PATTERN_DECL_NODE
(
fuse_qkv_split_out_v
);
// v
// core attention
PATTERN_DECL_NODE
(
qk_matmul_op
);
PATTERN_DECL_NODE
(
qk_matmul_out
);
PATTERN_DECL_NODE
(
qk_scale_op
);
PATTERN_DECL_NODE
(
qk_scale_out
);
PATTERN_DECL_NODE
(
add_mask_ele_add_op
);
PATTERN_DECL_NODE
(
add_mask_ele_add_mask
);
PATTERN_DECL_NODE
(
add_mask_ele_add_out
);
PATTERN_DECL_NODE
(
qk_softmax_op
);
PATTERN_DECL_NODE
(
qk_softmax_out
);
PATTERN_DECL_NODE
(
attn_dropout_op
);
PATTERN_DECL_NODE
(
attn_dropout_out
);
PATTERN_DECL_NODE
(
attn_dropout_mask
);
PATTERN_DECL_NODE
(
qkv_matmul_op
);
PATTERN_DECL_NODE
(
qkv_matmul_out
);
PATTERN_DECL_NODE
(
qkv_transpose_op
);
PATTERN_DECL_NODE
(
qkv_transpose_out
);
PATTERN_DECL_NODE
(
qkv_transpose_x_shape
);
PATTERN_DECL_NODE
(
qkv_reshape_op
);
PATTERN_DECL_NODE
(
qkv_reshape_out
);
PATTERN_DECL_NODE
(
qkv_reshape_x_shape
);
// out linear
PATTERN_DECL_NODE
(
out_linear_matmul_op
);
PATTERN_DECL_NODE
(
out_linear_matmul_w
);
PATTERN_DECL_NODE
(
out_linear_matmul_out
);
PATTERN_DECL_NODE
(
out_linear_ele_add_op
);
PATTERN_DECL_NODE
(
out_linear_ele_add_bias
);
PATTERN_DECL_NODE
(
out_linear_ele_add_out
);
PATTERN_DECL_NODE
(
out_linear_dropout_op
);
PATTERN_DECL_NODE
(
out_linear_dropout_out
);
PATTERN_DECL_NODE
(
out_linear_dropout_mask
);
// residual
PATTERN_DECL_NODE
(
residual_ele_add_op
);
PATTERN_DECL_NODE
(
residual_ele_add_out
);
// post layer norm
PATTERN_DECL_NODE
(
post_layer_norm_op
);
PATTERN_DECL_NODE
(
post_layer_norm_scale
);
PATTERN_DECL_NODE
(
post_layer_norm_bias
);
PATTERN_DECL_NODE
(
post_layer_norm_out
);
PATTERN_DECL_NODE
(
post_layer_norm_mean
);
PATTERN_DECL_NODE
(
post_layer_norm_variance
);
};
// Declare the grad pattern for multi head attention
struct
FusedAttentionGradPattern
:
public
PatternBase
{
FusedAttentionGradPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_attention_pattern"
)
{}
PDNode
*
operator
()(
PDNode
*
x
,
bool
pre_layer_norm
,
// pre ln
bool
post_layer_norm
,
// post ln
bool
has_attn_mask
,
// add attn mask to qk or not
bool
do_dropout
,
// dropout the softmax(qk) or not
bool
add_residual
);
// add residual to out linear or not
// TODO(Yuang Liu): add backward pattern
};
}
// namespace patterns
class
FusedAttentionsPass
:
public
FusePassBase
{
public:
virtual
~
FusedAttentionsPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
;
const
std
::
string
name_scope_
{
"fused_attention_pass"
};
private:
// The name rule for the helper function.
// The function name will contain at most five parts in order:
// 1. Do pre layer norm? [Pre]
// 2. Add mask in the core attention part? [Mask]
// 3. Do dropout in the core attention part? [Drop]
// 4. Add residual? [Res]
// 5. Do post layer norm? [Post]
// 6. Forward or Backward? [Fwd/Bwd]
// If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it.
ir
::
Graph
*
PreMaskDropResPostFwd
(
Graph
*
graph
)
const
;
ir
::
Graph
*
PreMaskDropResPostBwd
(
Graph
*
graph
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/pybind/parallel_executor.cc
浏览文件 @
b0ece266
...
@@ -714,6 +714,32 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT
...
@@ -714,6 +714,32 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT
build_strategy = static.BuildStrategy()
build_strategy = static.BuildStrategy()
build_strategy.fuse_gemm_epilogue = True
build_strategy.fuse_gemm_epilogue = True
)DOC"
)
)DOC"
)
.
def_property
(
"fused_attention"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fused_attention_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
PADDLE_ENFORCE_NE
(
self
.
IsFinalized
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"BuildStrategy has been finlaized, cannot be "
"configured again."
));
self
.
fused_attention_
=
b
;
},
R"DOC((bool, optional): fused_attention indicate whether
to fuse the whole multi head attention part with one op,
it may make the execution faster. Default is False.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fused_attention = True
)DOC"
)
.
def_property
(
.
def_property
(
"fuse_bn_act_ops"
,
"fuse_bn_act_ops"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_bn_act_ops_
;
},
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_bn_act_ops_
;
},
...
...
python/paddle/distributed/passes/cpp_pass.py
浏览文件 @
b0ece266
...
@@ -71,6 +71,19 @@ class FuseReluDepthwiseConvPass(CPPPassWrapper):
...
@@ -71,6 +71,19 @@ class FuseReluDepthwiseConvPass(CPPPassWrapper):
return
PassType
.
FUSION_OPT
return
PassType
.
FUSION_OPT
@
register_pass
(
"fused_attention"
)
class
FusedAttentionPass
(
CPPPassWrapper
):
def
__init__
(
self
):
super
().
__init__
()
@
property
def
cpp_name
(
self
):
return
"fused_attention_pass"
def
_type
(
self
):
return
PassType
.
FUSION_OPT
@
register_pass
(
"fuse_gemm_epilogue"
)
@
register_pass
(
"fuse_gemm_epilogue"
)
class
FuseGemmEpiloguePass
(
CPPPassWrapper
):
class
FuseGemmEpiloguePass
(
CPPPassWrapper
):
def
__init__
(
self
):
def
__init__
(
self
):
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
b0ece266
...
@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
...
@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
list
(
REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer
)
list
(
REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api
)
list
(
REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api
)
list
(
REMOVE_ITEM TEST_OPS test_fused_attention_pass
)
endif
()
endif
()
list
(
REMOVE_ITEM TEST_OPS test_fused_ec_moe_op
)
list
(
REMOVE_ITEM TEST_OPS test_fused_ec_moe_op
)
...
...
python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
0 → 100644
浏览文件 @
b0ece266
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid.core
as
core
import
paddle.nn.functional
as
F
from
paddle.distributed.passes
import
PassManager
,
new_pass
paddle
.
enable_static
()
class
MultiHeadAttention
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
embed_dim
,
num_heads
,
add_residual
=
True
,
pre_ln
=
True
,
post_ln
=
False
,
attn_dropout
=
True
,
):
super
(
MultiHeadAttention
,
self
).
__init__
()
self
.
embed_dim
=
embed_dim
self
.
kdim
=
embed_dim
self
.
vdim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
add_residual
=
add_residual
self
.
pre_ln
=
pre_ln
self
.
post_ln
=
post_ln
self
.
attn_dropout
=
attn_dropout
self
.
head_dim
=
embed_dim
//
num_heads
assert
(
self
.
head_dim
*
num_heads
==
self
.
embed_dim
),
"embed_dim must be divisible by num_heads"
self
.
norm1
=
paddle
.
nn
.
LayerNorm
(
embed_dim
,
epsilon
=
1e-5
)
self
.
norm2
=
paddle
.
nn
.
LayerNorm
(
embed_dim
,
epsilon
=
1e-5
)
self
.
qkv_proj
=
paddle
.
nn
.
Linear
(
embed_dim
,
3
*
embed_dim
)
self
.
out_proj
=
paddle
.
nn
.
Linear
(
embed_dim
,
embed_dim
)
self
.
dropout
=
paddle
.
nn
.
Dropout
(
0.1
,
mode
=
"upscale_in_train"
)
def
forward
(
self
,
x
,
attn_mask
=
None
):
residual
=
x
if
self
.
pre_ln
:
# pre layer norm
x
=
self
.
norm1
(
x
)
# compute qkv
qkv
=
self
.
qkv_proj
(
x
)
qkv
=
paddle
.
reshape
(
qkv
,
[
0
,
0
,
self
.
num_heads
,
3
*
self
.
head_dim
])
qkv
=
paddle
.
transpose
(
qkv
,
[
0
,
2
,
1
,
3
])
q
,
k
,
v
=
paddle
.
split
(
qkv
,
num_or_sections
=
3
,
axis
=-
1
)
# compute core attention
product
=
paddle
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
)
product
=
paddle
.
scale
(
product
,
scale
=
self
.
head_dim
**-
0.5
)
if
attn_mask
is
not
None
:
product
=
product
+
attn_mask
weights
=
F
.
softmax
(
product
)
if
self
.
attn_dropout
:
weights
=
F
.
dropout
(
weights
,
0.1
,
training
=
self
.
training
,
mode
=
"upscale_in_train"
)
out
=
paddle
.
matmul
(
weights
,
v
)
out
=
paddle
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
paddle
.
reshape
(
x
=
out
,
shape
=
[
0
,
0
,
out
.
shape
[
2
]
*
out
.
shape
[
3
]])
# project to output
out
=
self
.
out_proj
(
out
)
out
=
self
.
dropout
(
out
)
if
self
.
add_residual
:
out
=
residual
+
out
if
self
.
post_ln
:
# post layer norm
out
=
self
.
norm2
(
out
)
return
out
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestFusedAttentionPass
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
add_residual
=
True
self
.
pre_ln
=
True
self
.
post_ln
=
True
self
.
attn_dropout
=
True
self
.
add_mask
=
True
def
test_pass
(
self
):
batch_size
=
2
seq_len
=
1024
hidden_size
=
768
num_heads
=
12
x_data
=
np
.
random
.
rand
(
batch_size
,
seq_len
,
hidden_size
).
astype
(
'float32'
)
mask_data
=
np
.
random
.
rand
(
batch_size
,
num_heads
,
seq_len
,
seq_len
).
astype
(
'float32'
)
main_prog
=
paddle
.
static
.
Program
()
startup_prog
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_prog
,
startup_prog
):
data
=
paddle
.
static
.
data
(
name
=
"x"
,
shape
=
[
-
1
,
seq_len
,
hidden_size
],
dtype
=
'float32'
,
)
if
self
.
add_mask
:
attn_mask
=
paddle
.
static
.
data
(
name
=
"attn_mask"
,
shape
=
[
-
1
,
num_heads
,
seq_len
,
seq_len
],
dtype
=
'float32'
,
)
else
:
attn_mask
=
None
multi_head_attn
=
MultiHeadAttention
(
hidden_size
,
num_heads
,
add_residual
=
self
.
add_residual
,
pre_ln
=
self
.
pre_ln
,
post_ln
=
self
.
post_ln
,
attn_dropout
=
self
.
attn_dropout
,
)
out
=
multi_head_attn
(
data
,
attn_mask
)
loss
=
paddle
.
mean
(
out
)
sgd_optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
sgd_optimizer
.
minimize
(
loss
)
pass_manager
=
PassManager
([
new_pass
(
"fused_attention"
)])
pass_manager
.
apply
([
main_prog
],
[
startup_prog
])
ops
=
main_prog
.
global_block
().
ops
assert
ops
[
0
].
type
==
'reduce_mean'
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
0
)
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录