Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
643c94e4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
643c94e4
编写于
8月 05, 2022
作者:
C
carryyu
提交者:
GitHub
8月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance fused_multi_transformer_op(post_layer_norm) (#44789)
* add fused_multi_transformer post_layer_norm * add test post_layer_norm
上级
bdce552b
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
134 addition
and
23 deletion
+134
-23
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+79
-23
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
.../fluid/tests/unittests/test_fused_multi_transformer_op.py
+55
-0
未找到文件。
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
浏览文件 @
643c94e4
...
...
@@ -1279,9 +1279,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
ffn_ln_scales
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnScale"
);
auto
ffn_ln_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnBias"
);
Tensor
bias_dropout_residual_out
,
dropout_mask_out
;
auto
*
bias_dropout_residual_out_data
=
T
*
bias_dropout_residual_out_data
=
nullptr
;
if
(
pre_layer_norm
)
{
bias_dropout_residual_out_data
=
bias_dropout_residual_out
.
mutable_data
<
T
>
({
bsz
,
seq_len
,
dim_embed
},
place
);
}
auto
*
dropout_mask_out_data
=
dropout_mask_out
.
mutable_data
<
uint8_t
>
(
{
bsz
,
seq_len
,
dim_embed
},
place
);
...
...
@@ -1333,6 +1336,7 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step1: buf1 --> buf0
// step2: buf0 --> buf1
int
layers
=
qkv_weights
.
size
();
if
(
pre_layer_norm
)
{
if
(
layers
&
1
)
{
// odd, set buf1 as out
buf0
=
&
tmp_out
;
...
...
@@ -1342,6 +1346,10 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
buf0
=
out
;
buf1
=
&
tmp_out
;
}
}
else
{
buf0
=
&
tmp_out
;
buf1
=
out
;
}
for
(
int
i
=
0
;
i
<
layers
;
++
i
)
{
// step1. layer_norm
...
...
@@ -1355,9 +1363,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
else
if
(
!
pre_layer_norm
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unimplemented post_layer_norm for now."
));
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step1"
;
...
...
@@ -1367,8 +1372,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
const
Tensor
*
qkv_bias
=
qkv_biases
.
size
()
>
0
?
qkv_biases
[
i
]
:
nullptr
;
// NOTE: in decoder stage, bias is fused in fmha
const
Tensor
*
bias
=
time_step
?
nullptr
:
qkv_bias
;
if
(
!
pre_layer_norm
&&
i
==
0
)
{
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
input_x
,
bias
,
&
qkv_out
,
&
qkv_out
);
}
else
{
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
buf1
,
bias
,
&
qkv_out
,
&
qkv_out
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step2"
;
#endif
...
...
@@ -1451,10 +1461,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
VLOG
(
0
)
<<
"step3"
;
#endif
// step4. out_linear
if
(
pre_layer_norm
)
{
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf1
,
nullptr
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
}
else
{
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf0
,
nullptr
);
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step4"
;
#endif
...
...
@@ -1479,6 +1494,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
ln_mean_data
,
ln_var_data
);
}
else
{
auto
*
ln_scale_data
=
ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
]
->
data
<
U
>
();
auto
*
out_linear_bias_data
=
out_linear_biases
[
i
]
->
data
<
T
>
();
auto
*
residual_data
=
(
i
==
0
?
x_data
:
buf1
->
data
<
T
>
());
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf0
->
data
<
T
>
(),
residual_data
,
out_linear_bias_data
,
ln_scale_data
,
ln_bias_data
,
buf0
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step5"
;
...
...
@@ -1504,13 +1535,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
#endif
// step8. ffn matmul2
if
(
pre_layer_norm
)
{
ffn2_linear_compute
.
ComputeForward
(
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
nullptr
,
buf1
,
nullptr
);
}
else
{
ffn2_linear_compute
.
ComputeForward
(
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
nullptr
,
buf0
,
nullptr
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.0"
;
#endif
if
(
pre_layer_norm
)
{
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
}
else
{
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.1"
;
#endif
...
...
@@ -1543,14 +1583,30 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
dropout_mask_out_data
);
}
}
else
{
auto
*
ln_scale_data
=
ffn_ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ffn_ln_biases
[
i
]
->
data
<
U
>
();
ffn2_fused_dropout_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf0
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
ffn2_biases
[
i
]
->
data
<
T
>
(),
ln_scale_data
,
ln_bias_data
,
buf0
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step9"
;
#endif
if
(
pre_layer_norm
)
{
x_data
=
buf1
->
data
<
T
>
();
std
::
swap
(
buf0
,
buf1
);
}
}
}
};
}
// namespace operators
...
...
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
浏览文件 @
643c94e4
...
...
@@ -548,5 +548,60 @@ class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
self
.
layers
=
3
# odd layers
class
TestFusedMultiTransformerOpPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpCacheKVPostLayerNorm
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpCacheKVPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
x_type
=
np
.
float16
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpGenCacheKVPostLayerNorm
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录