Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b4a93884
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b4a93884
编写于
6月 10, 2022
作者:
L
limingshu
提交者:
GitHub
6月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize bwd layer_norm kernel with fast method (#42491)
上级
798e2e7e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
92 addition
and
31 deletion
+92
-31
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
...d/operators/fused/fused_layernorm_residual_dropout_bias.h
+1
-1
paddle/fluid/operators/layer_norm_kernel.cu.h
paddle/fluid/operators/layer_norm_kernel.cu.h
+83
-29
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+8
-1
未找到文件。
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
浏览文件 @
b4a93884
...
@@ -541,7 +541,7 @@ void LaunchLayernormResidualDropoutGrad(
...
@@ -541,7 +541,7 @@ void LaunchLayernormResidualDropoutGrad(
if
(
!
is_upscale_in_train
)
{
if
(
!
is_upscale_in_train
)
{
factor
=
static_cast
<
T
>
(
1.0
f
);
factor
=
static_cast
<
T
>
(
1.0
f
);
}
}
ln_bwd_
1024
_kernel_driver
<
ln_bwd_
fast
_kernel_driver
<
T
,
U
,
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
,
MaskType
>
(
T
,
U
,
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
,
MaskType
>
(
dev_ctx
,
rows
,
cols
,
epsilon
,
layernorm_src
,
scale
,
mean
,
var
,
d_out
,
dev_ctx
,
rows
,
cols
,
epsilon
,
layernorm_src
,
scale
,
mean
,
var
,
d_out
,
d_residual
,
d_scale
,
d_layernorm_bias
,
mask_data
,
factor
,
d_dropout_src
);
d_residual
,
d_scale
,
d_layernorm_bias
,
mask_data
,
factor
,
d_dropout_src
);
...
...
paddle/fluid/operators/layer_norm_kernel.cu.h
浏览文件 @
b4a93884
...
@@ -22,6 +22,8 @@ limitations under the License. */
...
@@ -22,6 +22,8 @@ limitations under the License. */
namespace
cub
=
hipcub
;
namespace
cub
=
hipcub
;
#endif
#endif
#include <iostream>
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/ddim.h"
...
@@ -428,7 +430,7 @@ template <
...
@@ -428,7 +430,7 @@ template <
int
THREADS_PER_CTA
=
WARPS_M
*
THREADS_PER_ROW
,
int
ROWS_PER_CTA
=
WARPS_M
,
int
THREADS_PER_CTA
=
WARPS_M
*
THREADS_PER_ROW
,
int
ROWS_PER_CTA
=
WARPS_M
,
int
ELTS_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
VecSize
,
int
ELTS_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
VecSize
,
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
>
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
fused_ln_bwd_
1024
_kernel
(
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
fused_ln_bwd_
fast
_kernel
(
const
int
rows
,
float
epsilon
,
const
T
*
__restrict__
x_ptr
,
const
int
rows
,
float
epsilon
,
const
T
*
__restrict__
x_ptr
,
const
ScaleT
*
__restrict__
gamma_ptr
,
const
U
*
__restrict__
mean_ptr
,
const
ScaleT
*
__restrict__
gamma_ptr
,
const
U
*
__restrict__
mean_ptr
,
const
U
*
__restrict__
var_ptr
,
const
T
*
__restrict__
dout_ptr
,
const
U
*
__restrict__
var_ptr
,
const
T
*
__restrict__
dout_ptr
,
...
@@ -671,7 +673,7 @@ template <
...
@@ -671,7 +673,7 @@ template <
int
ELTS_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
VecSize
,
int
ELTS_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
VecSize
,
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
,
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
,
int
VEC_COLS
=
ELTS_PER_ROW
/
VecSize
>
int
VEC_COLS
=
ELTS_PER_ROW
/
VecSize
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
ln_bwd_
1024
_final_kernel
(
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
ln_bwd_
fast
_final_kernel
(
const
int
rows
,
U
*
__restrict__
dg_part_
,
U
*
__restrict__
db_part_
,
const
int
rows
,
U
*
__restrict__
dg_part_
,
U
*
__restrict__
db_part_
,
ScaleT
*
__restrict__
dg_
,
ScaleT
*
__restrict__
db_
)
{
ScaleT
*
__restrict__
dg_
,
ScaleT
*
__restrict__
db_
)
{
using
Vec
=
phi
::
AlignedVector
<
U
,
VecSize
>
;
using
Vec
=
phi
::
AlignedVector
<
U
,
VecSize
>
;
...
@@ -795,7 +797,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
...
@@ -795,7 +797,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
*/
*/
template
<
typename
T
,
typename
U
,
typename
ScaleT
=
U
,
template
<
typename
T
,
typename
U
,
typename
ScaleT
=
U
,
typename
MaskType
=
uint8_t
>
typename
MaskType
=
uint8_t
>
void
ln_bwd_
1024
_kernel_driver
(
const
phi
::
GPUContext
&
dev_ctx
,
const
int
rows
,
void
ln_bwd_
fast
_kernel_driver
(
const
phi
::
GPUContext
&
dev_ctx
,
const
int
rows
,
const
int
cols
,
float
epsilon
,
const
T
*
x_ptr
,
const
int
cols
,
float
epsilon
,
const
T
*
x_ptr
,
const
ScaleT
*
scale_ptr
,
const
U
*
mean_ptr
,
const
ScaleT
*
scale_ptr
,
const
U
*
mean_ptr
,
const
U
*
var_ptr
,
const
T
*
dout_ptr
,
T
*
dx_ptr
,
const
U
*
var_ptr
,
const
T
*
dout_ptr
,
T
*
dx_ptr
,
...
@@ -804,10 +806,10 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
...
@@ -804,10 +806,10 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
T
factor
=
static_cast
<
T
>
(
0
),
T
factor
=
static_cast
<
T
>
(
0
),
T
*
d_dropout_src_ptr
=
nullptr
)
{
T
*
d_dropout_src_ptr
=
nullptr
)
{
auto
stream
=
dev_ctx
.
stream
();
auto
stream
=
dev_ctx
.
stream
();
if
(
cols
==
1024
)
{
if
(
cols
==
1024
||
cols
==
384
||
cols
==
256
)
{
// step-1: compute dx and reduced part results of dscale and dbias.
// step-1: compute dx and reduced part results of dscale and dbias.
const
int
WARPS_M
=
4
;
const
int
WARPS_M
=
4
;
// how many rows delt in a cta.
const
int
WARPS_N
=
1
;
const
int
WARPS_N
=
1
;
// how many warps to deal with a row.
const
int
BYTES_PER_LDG
=
16
;
const
int
BYTES_PER_LDG
=
16
;
const
int
VecSize
=
BYTES_PER_LDG
/
sizeof
(
T
);
const
int
VecSize
=
BYTES_PER_LDG
/
sizeof
(
T
);
...
@@ -839,20 +841,52 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
...
@@ -839,20 +841,52 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
"To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
"To compute fused_dropout_residual_ln grad, d_dropout_src_ptr "
"can't be null"
));
"can't be null"
));
}
}
fused_ln_bwd_1024_kernel
<
true
,
T
,
U
,
ScaleT
,
MaskType
,
VecSize
,
WARPS_M
,
#define LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
WARPS_N
,
BYTES_PER_LDG
>
fused_ln_bwd_fast_kernel<true, T, U, ScaleT, MaskType, vec_size, WARPS_M, \
<<<
gridx
,
THREADS_PER_CTA
,
0
,
stream
>>>
(
WARPS_N, BYTES_PER_LDG, ele_per_row> \
rows
,
epsilon
,
x_ptr
,
scale_ptr
,
mean_ptr
,
var_ptr
,
dout_ptr
,
<<<gridx, THREADS_PER_CTA, 0, stream>>>( \
dscale_temp_ptr
,
dbias_temp_ptr
,
dx_ptr
,
mask_ptr
,
factor
,
rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr, \
d_dropout_src_ptr
);
dscale_temp_ptr, dbias_temp_ptr, dx_ptr, mask_ptr, factor, \
d_dropout_src_ptr);
if
(
cols
==
1024
)
{
LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
(
VecSize
,
1024
);
}
else
{
switch
(
cols
)
{
case
384
:
LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
(
1
,
384
);
break
;
case
256
:
LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
(
VecSize
,
256
);
break
;
}
}
#undef LAUNCH_MASK_FUSED_LN_BWD_FAST_KERNEL
}
else
{
}
else
{
fused_ln_bwd_1024_kernel
<
false
,
T
,
U
,
ScaleT
,
MaskType
,
VecSize
,
WARPS_M
,
#define LAUNCH_FUSED_LN_BWD_FAST_KERNEL(vec_size, ele_per_row) \
WARPS_N
,
BYTES_PER_LDG
>
fused_ln_bwd_fast_kernel<false, T, U, ScaleT, MaskType, vec_size, WARPS_M, \
<<<
gridx
,
THREADS_PER_CTA
,
0
,
stream
>>>
(
WARPS_N, BYTES_PER_LDG, ele_per_row> \
rows
,
epsilon
,
x_ptr
,
scale_ptr
,
mean_ptr
,
var_ptr
,
dout_ptr
,
<<<gridx, THREADS_PER_CTA, 0, stream>>>( \
dscale_temp_ptr
,
dbias_temp_ptr
,
dx_ptr
);
rows, epsilon, x_ptr, scale_ptr, mean_ptr, var_ptr, dout_ptr, \
dscale_temp_ptr, dbias_temp_ptr, dx_ptr);
if
(
cols
==
1024
)
{
LAUNCH_FUSED_LN_BWD_FAST_KERNEL
(
VecSize
,
1024
);
}
else
{
switch
(
cols
)
{
case
384
:
LAUNCH_FUSED_LN_BWD_FAST_KERNEL
(
1
,
384
);
break
;
case
256
:
LAUNCH_FUSED_LN_BWD_FAST_KERNEL
(
VecSize
,
256
);
break
;
}
}
#undef LAUNCH_FUSED_LN_BWD_FAST_KERNEL
}
}
const
int
WARPS_M_2
=
16
;
const
int
WARPS_M_2
=
16
;
const
int
WARPS_N_2
=
1
;
const
int
WARPS_N_2
=
1
;
const
int
BYTES_PER_LDG_2
=
4
;
const
int
BYTES_PER_LDG_2
=
4
;
...
@@ -865,18 +899,36 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
...
@@ -865,18 +899,36 @@ void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
WARPS_M_2
*
THREADS_PER_ROW_2
;
// 16 * 32 = 512
WARPS_M_2
*
THREADS_PER_ROW_2
;
// 16 * 32 = 512
const
int
ROWS_PER_CTA_2
=
WARPS_M_2
;
// 16
const
int
ROWS_PER_CTA_2
=
WARPS_M_2
;
// 16
const
int
gridx_2
=
static_cast
<
int
>
(
std
::
ceil
(
1024
/
static_cast
<
float
>
(
THREADS_PER_ROW_2
*
VecSize_2
)));
// #blocks: 32,#threads_per_block: 512
// #blocks: 32,#threads_per_block: 512
// Note: it is not supported for double type.
// Note: it is not supported for double type.
if
(
sizeof
(
U
)
>
4
)
{
if
(
sizeof
(
U
)
>
4
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only support float and fp16 type"
));
"Only support float and fp16 type"
));
}
else
{
}
else
{
ln_bwd_1024_final_kernel
<
U
,
ScaleT
,
VecSize_2
,
WARPS_M_2
,
WARPS_N_2
,
int
gridx_2
=
0
;
BYTES_PER_LDG_2
>
<<<
gridx_2
,
THREADS_PER_CTA_2
,
0
,
stream
>>>
(
#define LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL(vec_size, ele_per_row) \
gridx
,
dscale_temp_ptr
,
dbias_temp_ptr
,
dscale_ptr
,
dbias_ptr
);
gridx_2 = static_cast<int>(std::ceil( \
ele_per_row / static_cast<float>(THREADS_PER_ROW_2 * vec_size))); \
ln_bwd_fast_final_kernel<U, ScaleT, vec_size, WARPS_M_2, WARPS_N_2, \
BYTES_PER_LDG_2, ele_per_row> \
<<<gridx_2, THREADS_PER_CTA_2, 0, stream>>>( \
gridx, dscale_temp_ptr, dbias_temp_ptr, dscale_ptr, dbias_ptr);
if
(
cols
==
1024
)
{
LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL
(
VecSize_2
,
1024
);
}
else
{
switch
(
cols
)
{
case
384
:
LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL
(
1
,
384
);
break
;
case
256
:
LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL
(
VecSize_2
,
256
);
break
;
}
}
#undef LAUNCH_LN_BWD_BETA_GAMMMA_KERNEL
}
}
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
@@ -1484,15 +1536,17 @@ static void LayerNormBackward(
...
@@ -1484,15 +1536,17 @@ static void LayerNormBackward(
case
7
:
// d_x != nullptr, d_scale != nullptr, d_bias != nullptr
case
7
:
// d_x != nullptr, d_scale != nullptr, d_bias != nullptr
{
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
bool
can_call_
1024
_kernel
=
false
;
bool
can_call_
fast
_kernel
=
false
;
// todo: rule out double type.
// todo: rule out double type.
if
(
feature_size
==
1024
&&
sizeof
(
T
)
<=
4
)
{
if
((
feature_size
==
1024
||
feature_size
==
384
||
can_call_1024_kernel
=
true
;
feature_size
==
256
)
&&
sizeof
(
T
)
<=
4
)
{
can_call_fast_kernel
=
true
;
}
}
VLOG
(
6
)
<<
"can_call_1024_kernel = "
<<
can_call_1024_kernel
;
if
(
can_call_1024_kernel
)
{
VLOG
(
6
)
<<
"can_call_fast_kernel = "
<<
can_call_fast_kernel
;
ln_bwd_1024_kernel_driver
<
if
(
can_call_fast_kernel
)
{
ln_bwd_fast_kernel_driver
<
T
,
U
,
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>>
(
T
,
U
,
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>>
(
dev_ctx
,
batch_size
,
feature_size
,
epsilon
,
x
,
scale
,
mean
,
var
,
dev_ctx
,
batch_size
,
feature_size
,
epsilon
,
x
,
scale
,
mean
,
var
,
d_y
,
d_x
,
d_scale
,
d_bias
);
d_y
,
d_x
,
d_scale
,
d_bias
);
...
...
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
浏览文件 @
b4a93884
...
@@ -247,7 +247,6 @@ class TestLayerNormOp(unittest.TestCase):
...
@@ -247,7 +247,6 @@ class TestLayerNormOp(unittest.TestCase):
def
test_check_forward_backward_with_scale_and_bias
(
self
):
def
test_check_forward_backward_with_scale_and_bias
(
self
):
self
.
check_forward_backward
(
shape
=
[
1
,
3
,
4
,
5
],
begin_norm_axis
=
1
)
self
.
check_forward_backward
(
shape
=
[
1
,
3
,
4
,
5
],
begin_norm_axis
=
1
)
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
begin_norm_axis
=
1
)
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
begin_norm_axis
=
1
)
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
begin_norm_axis
=
1
,
begin_norm_axis
=
1
,
...
@@ -288,6 +287,14 @@ class TestLayerNormOp(unittest.TestCase):
...
@@ -288,6 +287,14 @@ class TestLayerNormOp(unittest.TestCase):
begin_norm_axis
=
1
,
begin_norm_axis
=
1
,
has_scale
=
True
,
has_scale
=
True
,
has_bias
=
True
)
has_bias
=
True
)
self
.
check_forward_backward
(
shape
=
[
1
,
128
,
256
,
256
],
begin_norm_axis
=
3
,
has_scale
=
True
,
has_bias
=
True
)
self
.
check_forward_backward
(
shape
=
[
1
,
256
,
384
],
begin_norm_axis
=
2
,
has_scale
=
True
,
has_bias
=
True
)
class
TestLayerNormAPI
(
unittest
.
TestCase
):
class
TestLayerNormAPI
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录