BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle; in sync with the upstream project)
Commit 7cb49539 (unverified)
Authored by Zhang Zheng on Apr 28, 2022; committed via GitHub on Apr 28, 2022.

Suppport more scenes for fused_fast_ln (#42282)

* Suppport more scenes for fused_fast_ln
* fix

Parent: 687219fe

Showing 1 changed file with 119 additions and 46 deletions.
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h (+119, -46)
@@ -156,9 +156,9 @@ __global__ void FusedLayernormResidualDropoutBias(
 }

 /*
  * @brief layernorm(residual + dropout(x));
  * Conditions:
- * (1) The number of cols is 1024;
+ * (1) The number of cols is 768/1024/4096;
  * (2) layer_norm scale and bias is not null;
  * (3) linear bias is null;
  * @param
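In equation form (an editorial restatement of the comment above, not text from the patch): for row $i$, column $j$, per-element dropout keep-mask $m_{ij}$ and the scaling factor returned by GetFactor, the kernel computes

$$h_{ij} = \mathrm{residual}_{ij} + \mathrm{factor}\cdot m_{ij}\cdot\big(x_{ij} + b_j\big), \qquad y_{ij} = \gamma_j\,\frac{h_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}} + \beta_j,$$

where $b_j$ is the optional linear bias (treated as zero when bias_ptr is null) and $\mu_i$, $\sigma_i^2$ are the mean and variance of row $i$ of $h$.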
@@ -166,6 +166,7 @@ __global__ void FusedLayernormResidualDropoutBias(
  * cols: 1024
  * x_: [rows, cols], inputs
  * residual_:[rows, cols]
+ * bias_: [cols], linear bias, can be null
  * gamma_: [cols]: layernorm scale, not null
  * beta_: [cols], layernorm bias, not null
  * mask_out_: [rows, cols], dropout result
@@ -173,7 +174,7 @@ __global__ void FusedLayernormResidualDropoutBias(
  * y_: [rows, cols], layernorm result
  * mean_out_: [rows]: layernorm means
  * var_out_: [rows]: layernorm vars
  */
 template <
     typename T, typename U, typename ScaleT = U, typename MaskType = uint8_t,
     int VecSize = 8, int WARPS_M = 4, int WARPS_N = 1, int BYTES_PER_LDG = 16,
@@ -182,14 +183,16 @@ template <
     int THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
     int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
     int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
-__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
     int rows, int cols, uint64_t seed, const float dropout_prob,
     const bool is_upscale_in_train, const bool is_test,
     const uint64_t increment, const float epsilon,
     const T *__restrict__ x_ptr,
-    const T *__restrict__ residual_ptr, const ScaleT *__restrict__ gamma_ptr,
-    const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr,
-    U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
-    T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) {
+    const T *__restrict__ residual_ptr, const T *__restrict__ bias_ptr,
+    const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
+    MaskType *__restrict__ mask_out_ptr, U *__restrict__ mean_out_ptr,
+    U *__restrict__ var_out_ptr, T *__restrict__ residual_out_ptr,
+    T *__restrict__ y_ptr) {
+  __shared__ U smem[WARPS_M * WARPS_N];
   using Vec = phi::AlignedVector<T, VecSize>;
   using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
   using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
@@ -204,12 +207,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
   const int c = warp_n * THREADS_PER_WARP + lane;  // lane
   const int r = bidx * ROWS_PER_CTA + warp_m;      // row id
-  int idx = r * LN_NUM_COLS + c;
+  int idx = r * ELTS_PER_ROW + c;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, idx, increment, &state);

   T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);

+  // bias
+  Vec bias[LDGS];
+  if (bias_ptr != nullptr) {
+#pragma unroll
+    for (int it = 0, col = c; it < LDGS; it++) {
+      phi::Load<T, VecSize>(bias_ptr + col * VecSize, &bias[it]);
+      col += THREADS_PER_ROW;
+    }
+  }
+
   Vec_scale gamma[LDGS];
   Vec_scale beta[LDGS];
 #pragma unroll
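As context for the strided loads above: each thread owns LDGS vectors of VecSize consecutive elements, spaced THREADS_PER_ROW * VecSize columns apart. A small host-side sketch (illustrative only; PrintThreadColumns is not part of Paddle) reproduces that indexing for the cols = 1024, WARPS_N = 1 case:

#include <cstdio>

// Mirrors "for (int it = 0, col = c; it < LDGS; it++) ... col += THREADS_PER_ROW":
// prints the [first, last] column each load iteration touches for one thread.
void PrintThreadColumns(int c, int threads_per_row, int vec_size, int ldgs) {
  for (int it = 0, col = c; it < ldgs; ++it, col += threads_per_row) {
    std::printf("it=%d -> columns [%d, %d]\n", it, col * vec_size,
                col * vec_size + vec_size - 1);
  }
}

int main() {
  // cols = 1024: THREADS_PER_ROW = 32, VecSize = 8, LDGS = 4.
  // The thread with c == 0 covers columns 0-7, 256-263, 512-519, 768-775.
  PrintThreadColumns(0, 32, 8, 4);
  return 0;
}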
@@ -219,14 +232,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
col
+=
THREADS_PER_ROW
;
}
constexpr
U
rn
=
1.
f
/
U
(
LN_NUM_COLS
);
constexpr
U
rn
=
1.
f
/
U
(
ELTS_PER_ROW
);
for
(
int
row
=
r
;
row
<
rows
;
row
+=
gridDim
.
x
*
ROWS_PER_CTA
)
{
Vec
x
[
LDGS
];
Vec
residual
[
LDGS
];
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
phi
::
Load
<
T
,
VecSize
>
(
x_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
,
&
x
[
it
]);
phi
::
Load
<
T
,
VecSize
>
(
residual_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
,
phi
::
Load
<
T
,
VecSize
>
(
x_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
,
&
x
[
it
]);
phi
::
Load
<
T
,
VecSize
>
(
residual_ptr
+
row
*
ELTS_PER_ROW
+
col
*
VecSize
,
&
residual
[
it
]);
col
+=
THREADS_PER_ROW
;
}
...
...
@@ -255,14 +268,28 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
     // 4 * 8
     U xf[LDGS * VecSize];
-#pragma unroll
-    for (int it = 0; it < LDGS; it++) {
-#pragma unroll
-      for (int jt = 0; jt < VecSize; jt++) {
-        // dropout(x) + residual
-        x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
-                    residual[it][jt];
-        xf[it * VecSize + jt] = U(x[it][jt]);
-      }
-    }
+    if (bias_ptr != nullptr) {
+#pragma unroll
+      for (int it = 0; it < LDGS; it++) {
+#pragma unroll
+        for (int jt = 0; jt < VecSize; jt++) {
+          // dropout(x) + residual
+          x[it][jt] = (x[it][jt] + bias[it][jt]) *
+                          static_cast<T>(mask_vec[it][jt]) * factor +
+                      residual[it][jt];
+          xf[it * VecSize + jt] = U(x[it][jt]);
+        }
+      }
+    } else {
+#pragma unroll
+      for (int it = 0; it < LDGS; it++) {
+#pragma unroll
+        for (int jt = 0; jt < VecSize; jt++) {
+          // dropout(x) + residual
+          x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
+                      residual[it][jt];
+          xf[it * VecSize + jt] = U(x[it][jt]);
+        }
+      }
+    }
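Aside (restating the change, not part of the patch): the two branches differ only in the bias term. In the notation of the earlier equation, the bias path forms h_{ij} = residual_{ij} + factor * m_{ij} * (x_{ij} + b_j), while the else path keeps the original h_{ij} = residual_{ij} + factor * m_{ij} * x_{ij}; in both cases xf[] holds the same values converted to the accumulation type U for the row statistics that follow.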
@@ -270,9 +297,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
       phi::Store<T, VecSize>(
-          x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize);
+          x[it], residual_out_ptr + row * ELTS_PER_ROW + col * VecSize);
       phi::Store<MaskType, VecSize>(
-          mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize);
+          mask_vec[it], mask_out_ptr + row * ELTS_PER_ROW + col * VecSize);
       col += THREADS_PER_ROW;
     }
@@ -289,6 +316,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
mu_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
mu_local
,
it
);
}
if
(
WARPS_N
>
1
)
{
if
(
lane
==
0
)
{
smem
[
warp_m
*
WARPS_N
+
warp_n
]
=
mu_local
;
}
__syncthreads
();
if
(
tidx
==
0
)
{
mu_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARPS_N
;
++
it
)
{
mu_local
+=
smem
[
warp_m
*
WARPS_N
+
it
];
}
smem
[
warp_m
]
=
mu_local
;
}
__syncthreads
();
mu_local
=
smem
[
warp_m
];
}
mu_local
*=
rn
;
if
(
lane
==
0
)
{
mean_out_ptr
[
row
]
=
mu_local
;
...
...
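For readers unfamiliar with the pattern: the pre-existing __shfl_xor_sync loop is a butterfly all-reduce within one warp, and the new WARPS_N > 1 block combines the per-warp partial sums through shared memory. That cross-warp step only matters for cols = 4096, where WARPS_N = 4 and a single row spans several warps; for 768/1024 the branch is compiled out since WARPS_N = 1. A self-contained sketch of the warp-level step (generic illustration, not taken from the patch):

#include <cstdio>
#include <cuda_runtime.h>

// Butterfly all-reduce across one 32-lane warp: after the five xor-shuffle
// steps every lane holds the sum of all lanes' inputs (the pattern used for
// mu_local and var_local above).
__device__ float WarpAllReduceSum(float v) {
#pragma unroll
  for (int offset = 1; offset < 32; offset *= 2) {
    v += __shfl_xor_sync(0xffffffffu, v, offset);
  }
  return v;
}

__global__ void Demo(float* out) {
  float sum = WarpAllReduceSum(static_cast<float>(threadIdx.x));
  if (threadIdx.x == 0) *out = sum;  // 0 + 1 + ... + 31 = 496
}

int main() {
  float* out;
  cudaMallocManaged(&out, sizeof(float));
  Demo<<<1, 32>>>(out);
  cudaDeviceSynchronize();
  std::printf("warp sum = %f\n", *out);  // prints 496
  cudaFree(out);
  return 0;
}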
@@ -308,6 +351,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
     }
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = var_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        var_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          var_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = var_local;
+      }
+      __syncthreads();
+      var_local = smem[warp_m];
+    }
     U rsigma = rsqrtf(var_local * rn + epsilon);
     if (lane == 0) {
       // Note: the stored var is different for paddle(ln) and apex (fast ln).
@@ -332,7 +391,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
+      phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
       col += THREADS_PER_ROW;
     }
   }
@@ -390,12 +449,37 @@ void LaunchLayernormResidualDropoutBias(
     return;
   }

-  bool can_call_1024_kernel = false;
-  if (cols == 1024 && scale != nullptr && layernorm_bias != nullptr &&
-      bias == nullptr) {
-    can_call_1024_kernel = true;
+#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols)                                 \
+  case (cols): {                                                               \
+    constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024);                   \
+    constexpr int WARPS_M = 4 / WARPS_N;                                       \
+    const int THREADS_PER_WARP = 32;                                           \
+    const int BYTES_PER_LDG = 16;                                              \
+    const int VecSize = BYTES_PER_LDG / sizeof(T);                             \
+    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;          \
+    const int ROWS_PER_CTA = WARPS_M;                                          \
+    const int grid =                                                           \
+        static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA)));  \
+    fused_fast_ln_fwd_kernel<                                                  \
+        T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, uint8_t,      \
+        VecSize, WARPS_M, WARPS_N, BYTES_PER_LDG,                              \
+        cols><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(                     \
+        rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,          \
+        increment, epsilon, src, residual, bias, scale, layernorm_bias,        \
+        mask_data, mean, var, dst, layernorm_dst);                             \
+  } break
+
+#define LAUNCH_FUSED_FAST_LN_KERNEL       \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768);  \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)
+
+  bool can_call_fast_ln_kernel = false;
+  if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr &&
+      layernorm_bias != nullptr) {
+    can_call_fast_ln_kernel = true;
   }
-  VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
+  VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;

   const int VecSize = MAX_CACHE_BYTES / sizeof(T);
   if (cols % VecSize != 0) {
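To make the macro's arithmetic concrete, here is a small host-side sketch (illustrative; PrintTile is not part of the patch) that evaluates the derived launch geometry for the three supported widths, assuming T = half so that VecSize = BYTES_PER_LDG / sizeof(T) = 8:

#include <cstdio>

// Mirrors the arithmetic in LAUNCH_FUSED_FAST_LN_KERNEL_BASE (hypothetical
// helper, not part of Paddle).
constexpr int kThreadsPerWarp = 32;
constexpr int kVecSize = 8;  // BYTES_PER_LDG (16) / sizeof(half)

void PrintTile(int cols) {
  const int warps_n = cols < 1024 ? 1 : cols / 1024;
  const int warps_m = 4 / warps_n;
  const int threads_per_row = warps_n * kThreadsPerWarp;
  const int ldgs = cols / (threads_per_row * kVecSize);
  std::printf("cols=%4d  WARPS_M=%d  WARPS_N=%d  LDGS=%d  THREADS_PER_CTA=%d\n",
              cols, warps_m, warps_n, ldgs, warps_m * threads_per_row);
}

int main() {
  PrintTile(768);   // WARPS_M=4, WARPS_N=1, LDGS=3, 128 threads per CTA
  PrintTile(1024);  // WARPS_M=4, WARPS_N=1, LDGS=4, 128 threads per CTA
  PrintTile(4096);  // WARPS_M=1, WARPS_N=4, LDGS=4, 128 threads per CTA
  return 0;
}

In every supported case the CTA stays at 128 threads; widening the row only shifts work from rows per CTA (WARPS_M) to warps per row (WARPS_N).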
@@ -407,26 +491,15 @@ void LaunchLayernormResidualDropoutBias(
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
);
}
else
{
if
(
can_call_1024_kernel
)
{
const
int
WARPS_M
=
4
;
const
int
WARPS_N
=
1
;
const
int
THREADS_PER_WARP
=
32
;
const
int
BYTES_PER_LDG
=
16
;
const
int
VecSize
=
BYTES_PER_LDG
/
sizeof
(
T
);
const
int
THREADS_PER_CTA
=
WARPS_N
*
THREADS_PER_WARP
*
WARPS_M
;
const
int
ROWS_PER_CTA
=
WARPS_M
;
// Note: the grid can not exceed max_grid of the gpu.
const
int
grid
=
static_cast
<
int
>
(
std
::
ceil
(
rows
/
static_cast
<
float
>
(
ROWS_PER_CTA
)));
fused_ln_fwd_1024_kernel
<
T
,
U
,
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
,
uint8_t
,
VecSize
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG
><<<
grid
,
THREADS_PER_CTA
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
scale
,
layernorm_bias
,
mask_data
,
mean
,
var
,
dst
,
layernorm_dst
);
if
(
can_call_fast_ln_kernel
)
{
switch
(
cols
)
{
LAUNCH_FUSED_FAST_LN_KERNEL
;
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only when column is equal to 768/1024/4096 is supported for "
"now"
));
break
;
}
}
else
{
int
blockDim
=
GetDesiredBlockDim
(
cols
/
VecSize
);
FusedLayernormResidualDropoutBias
<
...
...
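For clarity, the switch above expands to one case per supported width, so any other column count falls through to the PADDLE_THROW default. Roughly (sketch only, launch bodies elided):

// Rough shape of the expanded dispatch:
switch (cols) {
  case (768): { /* derive WARPS_M/WARPS_N, launch fused_fast_ln_fwd_kernel<..., 768> */ } break;
  case (1024): { /* ... fused_fast_ln_fwd_kernel<..., 1024> */ } break;
  case (4096): { /* ... fused_fast_ln_fwd_kernel<..., 4096> */ } break;
  default:
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Only when column is equal to 768/1024/4096 is supported for now"));
    break;
}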