Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7cb49539
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7cb49539
编写于
4月 28, 2022
作者:
Z
Zhang Zheng
提交者:
GitHub
4月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Suppport more scenes for fused_fast_ln (#42282)
* Suppport more scenes for fused_fast_ln * fix
上级
687219fe
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
119 addition
and
46 deletion
+119
-46
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
...d/operators/fused/fused_layernorm_residual_dropout_bias.h
+119
-46
未找到文件。
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
浏览文件 @
7cb49539
...
...
@@ -156,9 +156,9 @@ __global__ void FusedLayernormResidualDropoutBias(
}
/*
* @brief layernorm(residual + dropout(x));
* @brief layernorm(residual + dropout(x));
* Conditions:
* (1) The number of cols is
1024
;
* (1) The number of cols is
768/1024/4096
;
* (2) layer_norm scale and bias is not null;
* (3) linear bias is null;
* @param
...
...
@@ -166,6 +166,7 @@ __global__ void FusedLayernormResidualDropoutBias(
* cols: 1024
* x_: [rows, cols], inputs
* residual_:[rows, cols]
* bias_: [cols], linear bias, can be null
* gamma_: [cols]: layernorm scale, not null
* beta_: [cols], layernorm bias, not null
* mask_out_: [rows, cols], dropout result
...
...
@@ -173,7 +174,7 @@ __global__ void FusedLayernormResidualDropoutBias(
* y_: [rows, cols], layernorm result
* mean_out_: [rows]: layernorm means
* var_out_: [rows]: layernorm vars
*/
*/
template
<
typename
T
,
typename
U
,
typename
ScaleT
=
U
,
typename
MaskType
=
uint8_t
,
int
VecSize
=
8
,
int
WARPS_M
=
4
,
int
WARPS_N
=
1
,
int
BYTES_PER_LDG
=
16
,
...
...
@@ -182,14 +183,16 @@ template <
int
THREADS_PER_CTA
=
WARPS_M
*
THREADS_PER_ROW
,
int
ROWS_PER_CTA
=
WARPS_M
,
int
ELTS_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
VecSize
,
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
fused_
ln_fwd_1024
_kernel
(
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
fused_
fast_ln_fwd
_kernel
(
int
rows
,
int
cols
,
uint64_t
seed
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
bool
is_test
,
const
uint64_t
increment
,
const
float
epsilon
,
const
T
*
__restrict__
x_ptr
,
const
T
*
__restrict__
residual_ptr
,
const
ScaleT
*
__restrict__
gamma_ptr
,
const
ScaleT
*
__restrict__
beta_ptr
,
MaskType
*
__restrict__
mask_out_ptr
,
U
*
__restrict__
mean_out_ptr
,
U
*
__restrict__
var_out_ptr
,
T
*
__restrict__
residual_out_ptr
,
T
*
__restrict__
y_ptr
)
{
const
T
*
__restrict__
residual_ptr
,
const
T
*
__restrict__
bias_ptr
,
const
ScaleT
*
__restrict__
gamma_ptr
,
const
ScaleT
*
__restrict__
beta_ptr
,
MaskType
*
__restrict__
mask_out_ptr
,
U
*
__restrict__
mean_out_ptr
,
U
*
__restrict__
var_out_ptr
,
T
*
__restrict__
residual_out_ptr
,
T
*
__restrict__
y_ptr
)
{
__shared__
U
smem
[
WARPS_M
*
WARPS_N
];
using
Vec
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
using
Vec_scale
=
phi
::
AlignedVector
<
ScaleT
,
VecSize
>
;
using
MaskStoreT
=
phi
::
AlignedVector
<
MaskType
,
VecSize
>
;
...
...
@@ -204,12 +207,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
const
int
c
=
warp_n
*
THREADS_PER_WARP
+
lane
;
// lane
const
int
r
=
bidx
*
ROWS_PER_CTA
+
warp_m
;
// row id
int
idx
=
r
*
LN_NUM_COLS
+
c
;
int
idx
=
r
*
ELTS_PER_ROW
+
c
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
increment
,
&
state
);
T
factor
=
GetFactor
<
T
>
(
dropout_prob
,
is_upscale_in_train
,
is_test
);
// bias
Vec
bias
[
LDGS
];
if
(
bias_ptr
!=
nullptr
)
{
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
phi
::
Load
<
T
,
VecSize
>
(
bias_ptr
+
col
*
VecSize
,
&
bias
[
it
]);
col
+=
THREADS_PER_ROW
;
}
}
Vec_scale
gamma
[
LDGS
];
Vec_scale
beta
[
LDGS
];
#pragma unroll
...
...
@@ -219,14 +232,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
col
+=
THREADS_PER_ROW
;
}
constexpr
U
rn
=
1.
f
/
U
(
LN_NUM_COLS
);
constexpr
U
rn
=
1.
f
/
U
(
ELTS_PER_ROW
);
for
(
int
row
=
r
;
row
<
rows
;
row
+=
gridDim
.
x
*
ROWS_PER_CTA
)
{
Vec
x
[
LDGS
];
Vec
residual
[
LDGS
];
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
phi
::
Load
<
T
,
VecSize
>
(
x_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
,
&
x
[
it
]);
phi
::
Load
<
T
,
VecSize
>
(
residual_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
,
phi
::
Load
<
T
,
VecSize
>
(
x_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
,
&
x
[
it
]);
phi
::
Load
<
T
,
VecSize
>
(
residual_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
,
&
residual
[
it
]);
col
+=
THREADS_PER_ROW
;
}
...
...
@@ -255,14 +268,28 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
// 4 * 8
U
xf
[
LDGS
*
VecSize
];
if
(
bias_ptr
!=
nullptr
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
// dropout(x) + residual
x
[
it
][
jt
]
=
x
[
it
][
jt
]
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
xf
[
it
*
VecSize
+
jt
]
=
U
(
x
[
it
][
jt
]);
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
// dropout(x) + residual
x
[
it
][
jt
]
=
(
x
[
it
][
jt
]
+
bias
[
it
][
jt
])
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
xf
[
it
*
VecSize
+
jt
]
=
U
(
x
[
it
][
jt
]);
}
}
}
else
{
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
// dropout(x) + residual
x
[
it
][
jt
]
=
x
[
it
][
jt
]
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
xf
[
it
*
VecSize
+
jt
]
=
U
(
x
[
it
][
jt
]);
}
}
}
...
...
@@ -270,9 +297,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
phi
::
Store
<
T
,
VecSize
>
(
x
[
it
],
residual_out_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
);
x
[
it
],
residual_out_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
);
phi
::
Store
<
MaskType
,
VecSize
>
(
mask_vec
[
it
],
mask_out_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
);
mask_vec
[
it
],
mask_out_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
);
col
+=
THREADS_PER_ROW
;
}
...
...
@@ -289,6 +316,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
mu_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
mu_local
,
it
);
}
if
(
WARPS_N
>
1
)
{
if
(
lane
==
0
)
{
smem
[
warp_m
*
WARPS_N
+
warp_n
]
=
mu_local
;
}
__syncthreads
();
if
(
tidx
==
0
)
{
mu_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARPS_N
;
++
it
)
{
mu_local
+=
smem
[
warp_m
*
WARPS_N
+
it
];
}
smem
[
warp_m
]
=
mu_local
;
}
__syncthreads
();
mu_local
=
smem
[
warp_m
];
}
mu_local
*=
rn
;
if
(
lane
==
0
)
{
mean_out_ptr
[
row
]
=
mu_local
;
...
...
@@ -308,6 +351,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
var_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
var_local
,
it
);
}
if
(
WARPS_N
>
1
)
{
if
(
lane
==
0
)
{
smem
[
warp_m
*
WARPS_N
+
warp_n
]
=
var_local
;
}
__syncthreads
();
if
(
tidx
==
0
)
{
var_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARPS_N
;
++
it
)
{
var_local
+=
smem
[
warp_m
*
WARPS_N
+
it
];
}
smem
[
warp_m
]
=
var_local
;
}
__syncthreads
();
var_local
=
smem
[
warp_m
];
}
U
rsigma
=
rsqrtf
(
var_local
*
rn
+
epsilon
);
if
(
lane
==
0
)
{
// Note: the stored var is different for paddle(ln) and apex (fast ln).
...
...
@@ -332,7 +391,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
phi
::
Store
<
T
,
VecSize
>
(
x
[
it
],
y_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
);
phi
::
Store
<
T
,
VecSize
>
(
x
[
it
],
y_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
);
col
+=
THREADS_PER_ROW
;
}
}
...
...
@@ -390,12 +449,37 @@ void LaunchLayernormResidualDropoutBias(
return
;
}
bool
can_call_1024_kernel
=
false
;
if
(
cols
==
1024
&&
scale
!=
nullptr
&&
layernorm_bias
!=
nullptr
&&
bias
==
nullptr
)
{
can_call_1024_kernel
=
true
;
#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \
case (cols): { \
constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \
constexpr int WARPS_M = 4 / WARPS_N; \
const int THREADS_PER_WARP = 32; \
const int BYTES_PER_LDG = 16; \
const int VecSize = BYTES_PER_LDG / sizeof(T); \
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \
const int ROWS_PER_CTA = WARPS_M; \
const int grid = \
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA))); \
fused_fast_ln_fwd_kernel< \
T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, uint8_t, \
VecSize, WARPS_M, WARPS_N, BYTES_PER_LDG, \
cols><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>( \
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, \
increment, epsilon, src, residual, bias, scale, layernorm_bias, \
mask_data, mean, var, dst, layernorm_dst); \
} break
#define LAUNCH_FUSED_FAST_LN_KERNEL \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)
bool
can_call_fast_ln_kernel
=
false
;
if
((
cols
==
768
||
cols
==
1024
||
cols
==
4096
)
&&
scale
!=
nullptr
&&
layernorm_bias
!=
nullptr
)
{
can_call_fast_ln_kernel
=
true
;
}
VLOG
(
6
)
<<
"can_call_
1024_kernel = "
<<
can_call_1024
_kernel
;
VLOG
(
6
)
<<
"can_call_
fast_ln_kernel = "
<<
can_call_fast_ln
_kernel
;
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
if
(
cols
%
VecSize
!=
0
)
{
...
...
@@ -407,26 +491,15 @@ void LaunchLayernormResidualDropoutBias(
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
);
}
else
{
if
(
can_call_1024_kernel
)
{
const
int
WARPS_M
=
4
;
const
int
WARPS_N
=
1
;
const
int
THREADS_PER_WARP
=
32
;
const
int
BYTES_PER_LDG
=
16
;
const
int
VecSize
=
BYTES_PER_LDG
/
sizeof
(
T
);
const
int
THREADS_PER_CTA
=
WARPS_N
*
THREADS_PER_WARP
*
WARPS_M
;
const
int
ROWS_PER_CTA
=
WARPS_M
;
// Note: the grid can not exceed max_grid of the gpu.
const
int
grid
=
static_cast
<
int
>
(
std
::
ceil
(
rows
/
static_cast
<
float
>
(
ROWS_PER_CTA
)));
fused_ln_fwd_1024_kernel
<
T
,
U
,
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
,
uint8_t
,
VecSize
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG
><<<
grid
,
THREADS_PER_CTA
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
scale
,
layernorm_bias
,
mask_data
,
mean
,
var
,
dst
,
layernorm_dst
);
if
(
can_call_fast_ln_kernel
)
{
switch
(
cols
)
{
LAUNCH_FUSED_FAST_LN_KERNEL
;
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only when column is equal to 768/1024/4096 is supported for "
"now"
));
break
;
}
}
else
{
int
blockDim
=
GetDesiredBlockDim
(
cols
/
VecSize
);
FusedLayernormResidualDropoutBias
<
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录