Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
643c94e4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
643c94e4
编写于
8月 05, 2022
作者:
C
carryyu
提交者:
GitHub
8月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance fused_multi_transformer_op(post_layer_norm) (#44789)
* add fused_multi_transformer post_layer_norm * add test post_layer_norm
上级
bdce552b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
134 addition
and
23 deletion
+134
-23
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+79
-23
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
.../fluid/tests/unittests/test_fused_multi_transformer_op.py
+55
-0
未找到文件。
paddle/fluid/operators/fused/fused_multi_transformer_op.cu
浏览文件 @
643c94e4
...
@@ -1279,9 +1279,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1279,9 +1279,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
auto
ffn_ln_scales
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnScale"
);
auto
ffn_ln_scales
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnScale"
);
auto
ffn_ln_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnBias"
);
auto
ffn_ln_biases
=
ctx
.
MultiInput
<
Tensor
>
(
"FFNLnBias"
);
Tensor
bias_dropout_residual_out
,
dropout_mask_out
;
Tensor
bias_dropout_residual_out
,
dropout_mask_out
;
auto
*
bias_dropout_residual_out_data
=
T
*
bias_dropout_residual_out_data
=
nullptr
;
bias_dropout_residual_out
.
mutable_data
<
T
>
({
bsz
,
seq_len
,
dim_embed
},
if
(
pre_layer_norm
)
{
place
);
bias_dropout_residual_out_data
=
bias_dropout_residual_out
.
mutable_data
<
T
>
({
bsz
,
seq_len
,
dim_embed
},
place
);
}
auto
*
dropout_mask_out_data
=
dropout_mask_out
.
mutable_data
<
uint8_t
>
(
auto
*
dropout_mask_out_data
=
dropout_mask_out
.
mutable_data
<
uint8_t
>
(
{
bsz
,
seq_len
,
dim_embed
},
place
);
{
bsz
,
seq_len
,
dim_embed
},
place
);
...
@@ -1333,14 +1336,19 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1333,14 +1336,19 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
// step1: buf1 --> buf0
// step1: buf1 --> buf0
// step2: buf0 --> buf1
// step2: buf0 --> buf1
int
layers
=
qkv_weights
.
size
();
int
layers
=
qkv_weights
.
size
();
if
(
layers
&
1
)
{
if
(
pre_layer_norm
)
{
// odd, set buf1 as out
if
(
layers
&
1
)
{
// odd, set buf1 as out
buf0
=
&
tmp_out
;
buf1
=
out
;
}
else
{
// even, set buf0 as out
buf0
=
out
;
buf1
=
&
tmp_out
;
}
}
else
{
buf0
=
&
tmp_out
;
buf0
=
&
tmp_out
;
buf1
=
out
;
buf1
=
out
;
}
else
{
// even, set buf0 as out
buf0
=
out
;
buf1
=
&
tmp_out
;
}
}
for
(
int
i
=
0
;
i
<
layers
;
++
i
)
{
for
(
int
i
=
0
;
i
<
layers
;
++
i
)
{
...
@@ -1355,9 +1363,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1355,9 +1363,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
buf1
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_mean_data
,
ln_var_data
);
ln_var_data
);
}
else
if
(
!
pre_layer_norm
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unimplemented post_layer_norm for now."
));
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step1"
;
VLOG
(
0
)
<<
"step1"
;
...
@@ -1367,8 +1372,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1367,8 +1372,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
const
Tensor
*
qkv_bias
=
qkv_biases
.
size
()
>
0
?
qkv_biases
[
i
]
:
nullptr
;
const
Tensor
*
qkv_bias
=
qkv_biases
.
size
()
>
0
?
qkv_biases
[
i
]
:
nullptr
;
// NOTE: in decoder stage, bias is fused in fmha
// NOTE: in decoder stage, bias is fused in fmha
const
Tensor
*
bias
=
time_step
?
nullptr
:
qkv_bias
;
const
Tensor
*
bias
=
time_step
?
nullptr
:
qkv_bias
;
qkv_compute
.
ComputeForward
(
if
(
!
pre_layer_norm
&&
i
==
0
)
{
qkv_weights
[
i
],
buf1
,
bias
,
&
qkv_out
,
&
qkv_out
);
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
input_x
,
bias
,
&
qkv_out
,
&
qkv_out
);
}
else
{
qkv_compute
.
ComputeForward
(
qkv_weights
[
i
],
buf1
,
bias
,
&
qkv_out
,
&
qkv_out
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step2"
;
VLOG
(
0
)
<<
"step2"
;
#endif
#endif
...
@@ -1451,10 +1461,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1451,10 +1461,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
VLOG
(
0
)
<<
"step3"
;
VLOG
(
0
)
<<
"step3"
;
#endif
#endif
// step4. out_linear
if
(
pre_layer_norm
)
{
out_linear_compute
.
ComputeForward
(
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf1
,
nullptr
);
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf1
,
nullptr
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
}
else
{
out_linear_compute
.
ComputeForward
(
out_linear_weights
[
i
],
&
fmha_out
,
nullptr
,
buf0
,
nullptr
);
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step4"
;
VLOG
(
0
)
<<
"step4"
;
#endif
#endif
...
@@ -1479,6 +1494,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1479,6 +1494,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
ln_mean_data
,
ln_mean_data
,
ln_var_data
);
ln_var_data
);
}
else
{
}
else
{
auto
*
ln_scale_data
=
ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ln_biases
[
i
]
->
data
<
U
>
();
auto
*
out_linear_bias_data
=
out_linear_biases
[
i
]
->
data
<
T
>
();
auto
*
residual_data
=
(
i
==
0
?
x_data
:
buf1
->
data
<
T
>
());
fused_dropout_layernorm_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf0
->
data
<
T
>
(),
residual_data
,
out_linear_bias_data
,
ln_scale_data
,
ln_bias_data
,
buf0
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step5"
;
VLOG
(
0
)
<<
"step5"
;
...
@@ -1504,13 +1535,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1504,13 +1535,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
#endif
#endif
// step8. ffn matmul2
// step8. ffn matmul2
ffn2_linear_compute
.
ComputeForward
(
if
(
pre_layer_norm
)
{
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
nullptr
,
buf1
,
nullptr
);
ffn2_linear_compute
.
ComputeForward
(
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
nullptr
,
buf1
,
nullptr
);
}
else
{
ffn2_linear_compute
.
ComputeForward
(
ffn2_weights
[
i
],
&
ffn1_dropout_out
,
nullptr
,
buf0
,
nullptr
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.0"
;
VLOG
(
0
)
<<
"step8.0"
;
#endif
#endif
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
if
(
pre_layer_norm
)
{
AllReduce
<
T
>
(
*
buf1
,
ring_id
,
dev_ctx
);
}
else
{
AllReduce
<
T
>
(
*
buf0
,
ring_id
,
dev_ctx
);
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step8.1"
;
VLOG
(
0
)
<<
"step8.1"
;
#endif
#endif
...
@@ -1543,12 +1583,28 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
...
@@ -1543,12 +1583,28 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
dropout_mask_out_data
);
dropout_mask_out_data
);
}
}
}
else
{
}
else
{
auto
*
ln_scale_data
=
ffn_ln_scales
[
i
]
->
data
<
U
>
();
auto
*
ln_bias_data
=
ffn_ln_biases
[
i
]
->
data
<
U
>
();
ffn2_fused_dropout_helper
.
LayernormResidualDropoutBias
(
dev_ctx
,
buf0
->
data
<
T
>
(),
buf1
->
data
<
T
>
(),
ffn2_biases
[
i
]
->
data
<
T
>
(),
ln_scale_data
,
ln_bias_data
,
buf0
->
data
<
T
>
(),
dropout_mask_out_data
,
buf1
->
data
<
T
>
(),
ln_mean_data
,
ln_var_data
);
}
}
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG
(
0
)
<<
"step9"
;
VLOG
(
0
)
<<
"step9"
;
#endif
#endif
x_data
=
buf1
->
data
<
T
>
();
if
(
pre_layer_norm
)
{
std
::
swap
(
buf0
,
buf1
);
x_data
=
buf1
->
data
<
T
>
();
std
::
swap
(
buf0
,
buf1
);
}
}
}
}
}
};
};
...
...
python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
浏览文件 @
643c94e4
...
@@ -548,5 +548,60 @@ class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
...
@@ -548,5 +548,60 @@ class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
self
.
layers
=
3
# odd layers
self
.
layers
=
3
# odd layers
class
TestFusedMultiTransformerOpPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpCacheKVPostLayerNorm
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpCacheKVPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
query_length
=
1
self
.
key_length
,
self
.
value_length
=
1
,
1
self
.
x_type
=
np
.
float16
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpGenCacheKVPostLayerNorm
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
pre_layer_norm
=
False
class
TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16
(
TestFusedMultiTransformerOp
):
def
config
(
self
):
super
().
config
()
self
.
has_cache_kv
=
True
self
.
gen_cache_kv
=
True
self
.
x_type
=
np
.
float16
self
.
layers
=
3
# odd layers
self
.
pre_layer_norm
=
False
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录