Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
01d04be6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
01d04be6
编写于
1月 26, 2022
作者:
L
Li Min
提交者:
GitHub
1月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize layer norm forward when cols is 1024. (#39167)
* Optimize layer_norm fwd when cols is 1024.
上级
6efb9f59
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
395 addition
and
18 deletion
+395
-18
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
...d/operators/fused/fused_layernorm_residual_dropout_bias.h
+223
-7
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
...ators/fused/fused_layernorm_residual_dropout_bias_test.cu
+14
-8
paddle/fluid/operators/layer_norm_kernel.cu.h
paddle/fluid/operators/layer_norm_kernel.cu.h
+115
-0
paddle/fluid/operators/layer_norm_op.cu
paddle/fluid/operators/layer_norm_op.cu
+41
-3
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+2
-0
未找到文件。
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
浏览文件 @
01d04be6
...
...
@@ -19,6 +19,8 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
#define LN_NUM_COLS 1024
template
<
typename
T
>
using
CudnnDataType
=
platform
::
CudnnDataType
<
T
>
;
template
<
typename
T
>
...
...
@@ -153,6 +155,191 @@ __global__ void FusedLayernormResidualDropoutBias(
invvar
);
}
/*
* @brief layernorm(residual + dropout(x));
* Conditions:
* (1) The number of cols is 1024;
* (2) layer_norm scale and bias is not null;
* (3) linear bias is null;
* @param
* rows: batch_size * seq_len
* cols: 1024
* x_: [rows, cols], inputs
* residual_:[rows, cols]
* gamma_: [cols]: layernorm scale, not null
* beta_: [cols], layernorm bias, not null
* mask_out_: [rows, cols], dropout result
* residual_out_: [rows, cols], residual + dropout(src)
* y_: [rows, cols], layernorm result
* mean_out_: [rows]: layernorm means
* var_out_: [rows]: layernorm vars
*/
template
<
typename
T
,
typename
U
,
typename
ScaleT
=
U
,
typename
MaskType
=
uint8_t
,
int
VecSize
=
8
,
int
WARPS_M
=
4
,
int
WARPS_N
=
1
,
int
BYTES_PER_LDG
=
16
,
int
ELTS_PER_ROW
=
1024
,
int
THREADS_PER_WARP
=
32
,
int
THREADS_PER_ROW
=
WARPS_N
*
THREADS_PER_WARP
,
int
THREADS_PER_CTA
=
WARPS_M
*
THREADS_PER_ROW
,
int
ROWS_PER_CTA
=
WARPS_M
,
int
ELTS_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
VecSize
,
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
fused_ln_fwd_1024_kernel
(
int
rows
,
int
cols
,
uint64_t
seed
,
const
float
dropout_prob
,
const
bool
is_upscale_in_train
,
const
bool
is_test
,
const
uint64_t
increment
,
const
float
epsilon
,
const
T
*
__restrict__
x_ptr
,
const
T
*
__restrict__
residual_ptr
,
const
ScaleT
*
__restrict__
gamma_ptr
,
const
ScaleT
*
__restrict__
beta_ptr
,
MaskType
*
__restrict__
mask_out_ptr
,
U
*
__restrict__
mean_out_ptr
,
U
*
__restrict__
var_out_ptr
,
T
*
__restrict__
residual_out_ptr
,
T
*
__restrict__
y_ptr
)
{
using
Vec
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
Vec_scale
=
platform
::
AlignedVector
<
ScaleT
,
VecSize
>
;
using
MaskStoreT
=
platform
::
AlignedVector
<
MaskType
,
VecSize
>
;
const
int
tidx
=
threadIdx
.
x
;
const
int
bidx
=
blockIdx
.
x
;
const
int
lane
=
tidx
%
THREADS_PER_WARP
;
// 0, 1, ..., 31
const
int
warp
=
tidx
/
THREADS_PER_WARP
;
// 0, 1, 2, 3
const
int
warp_n
=
warp
%
WARPS_N
;
// 0
const
int
warp_m
=
warp
/
WARPS_N
;
// 0, 1, 2, 3
const
int
c
=
warp_n
*
THREADS_PER_WARP
+
lane
;
// lane
const
int
r
=
bidx
*
ROWS_PER_CTA
+
warp_m
;
// row id
int
idx
=
r
*
LN_NUM_COLS
+
c
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
increment
,
&
state
);
T
factor
=
GetFactor
<
T
>
(
dropout_prob
,
is_upscale_in_train
,
is_test
);
Vec_scale
gamma
[
LDGS
];
Vec_scale
beta
[
LDGS
];
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
platform
::
Load
<
ScaleT
,
VecSize
>
(
gamma_ptr
+
col
*
VecSize
,
&
gamma
[
it
]);
platform
::
Load
<
ScaleT
,
VecSize
>
(
beta_ptr
+
col
*
VecSize
,
&
beta
[
it
]);
col
+=
THREADS_PER_ROW
;
}
constexpr
U
rn
=
1.
f
/
U
(
LN_NUM_COLS
);
for
(
int
row
=
r
;
row
<
rows
;
row
+=
gridDim
.
x
*
ROWS_PER_CTA
)
{
Vec
x
[
LDGS
];
Vec
residual
[
LDGS
];
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
platform
::
Load
<
T
,
VecSize
>
(
x_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
,
&
x
[
it
]);
platform
::
Load
<
T
,
VecSize
>
(
residual_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
,
&
residual
[
it
]);
col
+=
THREADS_PER_ROW
;
}
MaskStoreT
mask_vec
[
LDGS
];
if
(
!
is_test
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
float
rand
[
VecSize
];
RandVec
<
VecSize
>
(
&
state
,
rand
);
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
#pragma unroll
mask_vec
[
it
][
jt
]
=
static_cast
<
MaskType
>
(
rand
[
jt
]
>=
dropout_prob
);
}
}
}
else
{
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
mask_vec
[
it
][
jt
]
=
static_cast
<
MaskType
>
(
1
);
}
}
}
// 4 * 8
U
xf
[
LDGS
*
VecSize
];
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
// dropout(x) + residual
x
[
it
][
jt
]
=
x
[
it
][
jt
]
*
static_cast
<
T
>
(
mask_vec
[
it
][
jt
])
*
factor
+
residual
[
it
][
jt
];
xf
[
it
*
VecSize
+
jt
]
=
U
(
x
[
it
][
jt
]);
}
}
// store dropout_residual_out and mask_out
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
platform
::
Store
<
T
,
VecSize
>
(
x
[
it
],
residual_out_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
);
platform
::
Store
<
MaskType
,
VecSize
>
(
mask_vec
[
it
],
mask_out_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
);
col
+=
THREADS_PER_ROW
;
}
U
mu_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
mu_local
+=
xf
[
it
*
VecSize
+
jt
];
}
}
#pragma unroll
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
mu_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
mu_local
,
it
);
}
mu_local
*=
rn
;
if
(
lane
==
0
)
{
mean_out_ptr
[
row
]
=
mu_local
;
}
U
var_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
U
diff
=
xf
[
it
*
VecSize
+
jt
]
-
mu_local
;
var_local
+=
diff
*
diff
;
}
}
#pragma unroll
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
var_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
var_local
,
it
);
}
U
rsigma
=
rsqrtf
(
var_local
*
rn
+
epsilon
);
if
(
lane
==
0
)
{
// Note: the stored var is different for paddle(ln) and apex (fast ln).
// var_out_ptr[row] = rsigma;
var_out_ptr
[
row
]
=
var_local
*
rn
;
}
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
// use fp16 to compute
// ScaleT tmp = static_cast<ScaleT>(rsigma * (xf[it * VecSize + jt] -
// mu_local));
// x[it][jt] = gamma[it][jt] * tmp + beta[it][jt];
// cast to fp32 to compute
U
tmp
=
rsigma
*
(
static_cast
<
U
>
(
xf
[
it
*
VecSize
+
jt
])
-
mu_local
);
x
[
it
][
jt
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
gamma
[
it
][
jt
])
*
tmp
+
static_cast
<
U
>
(
beta
[
it
][
jt
]));
}
}
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
platform
::
Store
<
T
,
VecSize
>
(
x
[
it
],
y_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
);
col
+=
THREADS_PER_ROW
;
}
}
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
...
...
@@ -205,6 +392,13 @@ void LaunchLayernormResidualDropoutBias(
return
;
}
bool
can_call_1024_kernel
=
false
;
if
(
cols
==
1024
&&
scale
!=
nullptr
&&
layernorm_bias
!=
nullptr
&&
bias
==
nullptr
)
{
can_call_1024_kernel
=
true
;
}
VLOG
(
6
)
<<
"can_call_1024_kernel = "
<<
can_call_1024_kernel
;
const
int
VecSize
=
MAX_CACHE_BYTES
/
sizeof
(
T
);
if
(
cols
%
VecSize
!=
0
)
{
int
blockDim
=
GetDesiredBlockDim
(
cols
);
...
...
@@ -215,13 +409,35 @@ void LaunchLayernormResidualDropoutBias(
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
);
}
else
{
int
blockDim
=
GetDesiredBlockDim
(
cols
/
VecSize
);
FusedLayernormResidualDropoutBias
<
T
,
uint8_t
,
VecSize
,
U
,
ScaleBiasWithSameTypeX
><<<
rows
,
blockDim
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
);
if
(
can_call_1024_kernel
)
{
const
int
WARPS_M
=
4
;
const
int
WARPS_N
=
1
;
const
int
THREADS_PER_WARP
=
32
;
const
int
BYTES_PER_LDG
=
16
;
const
int
VecSize
=
BYTES_PER_LDG
/
sizeof
(
T
);
const
int
THREADS_PER_CTA
=
WARPS_N
*
THREADS_PER_WARP
*
WARPS_M
;
const
int
ROWS_PER_CTA
=
WARPS_M
;
// Note: the grid can not exceed max_grid of the gpu.
const
int
grid
=
static_cast
<
int
>
(
std
::
ceil
(
rows
/
static_cast
<
float
>
(
ROWS_PER_CTA
)));
fused_ln_fwd_1024_kernel
<
T
,
U
,
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
,
uint8_t
,
VecSize
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG
><<<
grid
,
THREADS_PER_CTA
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
scale
,
layernorm_bias
,
mask_data
,
mean
,
var
,
dst
,
layernorm_dst
);
}
else
{
int
blockDim
=
GetDesiredBlockDim
(
cols
/
VecSize
);
FusedLayernormResidualDropoutBias
<
T
,
uint8_t
,
VecSize
,
U
,
ScaleBiasWithSameTypeX
><<<
rows
,
blockDim
,
0
,
ctx
.
stream
()
>>>
(
rows
,
cols
,
seed
,
dropout_prob
,
is_upscale_in_train
,
is_test
,
increment
,
epsilon
,
src
,
residual
,
bias
,
scale
,
layernorm_bias
,
mask_data
,
dst
,
layernorm_dst
,
mean
,
var
);
}
}
}
...
...
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
浏览文件 @
01d04be6
...
...
@@ -66,12 +66,10 @@ struct TestFusedLayernormResidualDropoutBias {
ctx
=
reinterpret_cast
<
platform
::
CUDADeviceContext
*>
(
devicectx
);
}
TestFusedLayernormResidualDropoutBias
(
int
_rows
,
int
_cols
,
uint64_t
_seed
=
0
,
float
_dropout_prob
=
0.0
,
float
_epsilon
=
0.00001
f
,
bool
_is_upscale_in_train
=
false
,
bool
_is_test
=
false
)
{
TestFusedLayernormResidualDropoutBias
(
int
_rows
,
int
_cols
,
uint64_t
_seed
=
0
,
float
_dropout_prob
=
0.0
,
float
_epsilon
=
0.00001
f
,
bool
_is_upscale_in_train
=
false
,
bool
_is_test
=
false
,
bool
_has_bias
=
true
)
{
rows
=
_rows
;
cols
=
_cols
;
seed
=
_seed
;
...
...
@@ -79,7 +77,7 @@ struct TestFusedLayernormResidualDropoutBias {
epsilon
=
_epsilon
;
is_upscale_in_train
=
_is_upscale_in_train
;
is_test
=
_is_test
;
has_bias
=
true
;
has_bias
=
_has_bias
;
has_scale
=
true
;
has_layernorm_bias
=
true
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
@@ -283,7 +281,6 @@ static void BaseTest(const bool is_fp16 = false) {
}
}
}
TEST
(
FusedDropout
,
GPUFusedLayernormResidualDropoutBias
)
{
BaseTest
<
float
>
();
}
TEST
(
FusedDropout
,
GPUFusedLayernormResidualDropoutBiasDouble
)
{
...
...
@@ -330,3 +327,12 @@ TEST(FusedDropout, GPUFusedLayernormResidualDropoutLargeShape) {
test
.
Run
();
test
.
CheckOut
(
static_cast
<
float
>
(
1e-4
));
}
TEST
(
FusedDropout
,
GPUFusedLayernormResidualDropoutFp16MLperf
)
{
const
int
rows
=
512
;
const
int
cols
=
1024
;
TestFusedLayernormResidualDropoutBias
<
platform
::
float16
>
test
(
rows
,
cols
,
0
,
0
,
0.00001
f
,
false
,
false
,
false
);
test
.
Run
();
test
.
CheckOut
(
static_cast
<
platform
::
float16
>
(
1e-2
));
}
paddle/fluid/operators/layer_norm_kernel.cu.h
浏览文件 @
01d04be6
...
...
@@ -23,6 +23,7 @@ namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
...
...
@@ -35,6 +36,8 @@ using CudnnDataType = platform::CudnnDataType<T>;
template
<
typename
T
>
using
LayerNormParamType
=
typename
CudnnDataType
<
T
>::
BatchNormParamType
;
#define LN_NUM_COLS 1024
inline
static
int
GetDesiredBlockDim
(
int64_t
block_dim
)
{
#ifdef __HIPCC__
const
int
kMaxBlockDim
=
256
;
...
...
@@ -169,6 +172,118 @@ __inline__ __device__ half rsqrt_(const half val) {
}
#endif
#ifdef PADDLE_WITH_CUDA
template
<
typename
T
,
typename
U
,
typename
ScaleT
=
U
,
int
VecSize
=
8
,
int
WARPS_M
=
4
,
int
WARPS_N
=
1
,
int
BYTES_PER_LDG
=
16
,
int
ELTS_PER_ROW
=
1024
,
int
THREADS_PER_WARP
=
32
,
int
THREADS_PER_ROW
=
WARPS_N
*
THREADS_PER_WARP
,
int
THREADS_PER_CTA
=
WARPS_M
*
THREADS_PER_ROW
,
int
ROWS_PER_CTA
=
WARPS_M
,
int
ELTS_PER_ROW_PER_CTA
=
THREADS_PER_ROW
*
VecSize
,
int
LDGS
=
ELTS_PER_ROW
/
ELTS_PER_ROW_PER_CTA
>
__global__
__launch_bounds__
(
THREADS_PER_CTA
)
void
ln_fwd_1024_kernel
(
int
rows
,
int
cols
,
const
float
epsilon
,
const
T
*
__restrict__
x_ptr
,
const
ScaleT
*
__restrict__
gamma_ptr
,
const
ScaleT
*
__restrict__
beta_ptr
,
U
*
__restrict__
mean_out_ptr
,
U
*
__restrict__
var_out_ptr
,
T
*
__restrict__
y_ptr
)
{
using
Vec
=
platform
::
AlignedVector
<
T
,
VecSize
>
;
using
Vec_scale
=
platform
::
AlignedVector
<
ScaleT
,
VecSize
>
;
const
int
tidx
=
threadIdx
.
x
;
const
int
bidx
=
blockIdx
.
x
;
const
int
lane
=
tidx
%
THREADS_PER_WARP
;
// 0, 1, ..., 31
const
int
warp
=
tidx
/
THREADS_PER_WARP
;
// 0, 1, 2, 3
const
int
warp_n
=
warp
%
WARPS_N
;
// 0
const
int
warp_m
=
warp
/
WARPS_N
;
// 0, 1, 2, 3
const
int
c
=
warp_n
*
THREADS_PER_WARP
+
lane
;
// lane
const
int
r
=
bidx
*
ROWS_PER_CTA
+
warp_m
;
// row id
Vec_scale
gamma
[
LDGS
];
Vec_scale
beta
[
LDGS
];
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
platform
::
Load
<
ScaleT
,
VecSize
>
(
gamma_ptr
+
col
*
VecSize
,
&
gamma
[
it
]);
platform
::
Load
<
ScaleT
,
VecSize
>
(
beta_ptr
+
col
*
VecSize
,
&
beta
[
it
]);
col
+=
THREADS_PER_ROW
;
}
constexpr
U
rn
=
1.
f
/
U
(
LN_NUM_COLS
);
for
(
int
row
=
r
;
row
<
rows
;
row
+=
gridDim
.
x
*
ROWS_PER_CTA
)
{
Vec
x
[
LDGS
];
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
platform
::
Load
<
T
,
VecSize
>
(
x_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
,
&
x
[
it
]);
col
+=
THREADS_PER_ROW
;
}
U
xf
[
LDGS
*
VecSize
];
U
mu_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
xf
[
it
*
VecSize
+
jt
]
=
U
(
x
[
it
][
jt
]);
mu_local
+=
xf
[
it
*
VecSize
+
jt
];
}
}
#pragma unroll
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
mu_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
mu_local
,
it
);
}
mu_local
*=
rn
;
if
(
lane
==
0
)
{
mean_out_ptr
[
row
]
=
mu_local
;
}
U
var_local
=
0.
f
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
U
diff
=
xf
[
it
*
VecSize
+
jt
]
-
mu_local
;
var_local
+=
diff
*
diff
;
}
}
#pragma unroll
for
(
int
it
=
1
;
it
<
THREADS_PER_WARP
;
it
*=
2
)
{
var_local
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
var_local
,
it
);
}
// Note: to assure if it is right for double
U
rsigma
=
rsqrtf
(
var_local
*
rn
+
epsilon
);
if
(
lane
==
0
)
{
var_out_ptr
[
row
]
=
var_local
*
rn
;
}
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
#pragma unroll
for
(
int
jt
=
0
;
jt
<
VecSize
;
jt
++
)
{
// use fp16 to compute
// ScaleT tmp = static_cast<ScaleT>(rsigma * (xf[it * VecSize + jt] -
// mu_local));
// x[it][jt] = gamma[it][jt] * tmp + beta[it][jt];
// cast to fp32 to compute
U
tmp
=
(
rsigma
*
(
static_cast
<
U
>
(
xf
[
it
*
VecSize
+
jt
])
-
mu_local
));
x
[
it
][
jt
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
gamma
[
it
][
jt
])
*
tmp
+
static_cast
<
U
>
(
beta
[
it
][
jt
]));
}
}
#pragma unroll
for
(
int
it
=
0
,
col
=
c
;
it
<
LDGS
;
it
++
)
{
platform
::
Store
<
T
,
VecSize
>
(
x
[
it
],
y_ptr
+
row
*
LN_NUM_COLS
+
col
*
VecSize
);
col
+=
THREADS_PER_ROW
;
}
}
}
#endif
template
<
typename
T
,
typename
U
,
bool
ScaleBiasWithSameTypeX
>
using
LayerNormScaleBiasT
=
typename
std
::
conditional
<
ScaleBiasWithSameTypeX
,
T
,
U
>::
type
;
...
...
paddle/fluid/operators/layer_norm_op.cu
浏览文件 @
01d04be6
...
...
@@ -112,11 +112,49 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
} \
} while (0)
if
(
is_scale_bias_same_dtype_with_x
)
{
PADDLE_LAUNCH_LAYERNORM_FWD
(
T
,
true
);
#ifdef PADDLE_WITH_CUDA
bool
can_call_1024_kernel
=
false
;
if
(
feature_size
==
1024
&&
scale
!=
nullptr
&&
bias
!=
nullptr
)
{
can_call_1024_kernel
=
true
;
}
if
(
can_call_1024_kernel
)
{
const
int
WARPS_M
=
4
;
const
int
WARPS_N
=
1
;
const
int
THREADS_PER_WARP
=
32
;
const
int
BYTES_PER_LDG
=
16
;
const
int
VecSize
=
BYTES_PER_LDG
/
sizeof
(
T
);
const
int
THREADS_PER_CTA
=
WARPS_N
*
THREADS_PER_WARP
*
WARPS_M
;
const
int
ROWS_PER_CTA
=
WARPS_M
;
const
int
grid
=
static_cast
<
int
>
(
std
::
ceil
(
batch_size
/
static_cast
<
float
>
(
ROWS_PER_CTA
)));
if
(
is_scale_bias_same_dtype_with_x
)
{
ln_fwd_1024_kernel
<
T
,
U
,
T
,
VecSize
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG
><<<
grid
,
THREADS_PER_CTA
,
0
,
stream
>>>
(
batch_size
,
feature_size
,
epsilon
,
x_data
,
static_cast
<
const
T
*>
(
void_scale_data
),
static_cast
<
const
T
*>
(
void_bias_data
),
mean_data
,
var_data
,
y_data
);
}
else
{
ln_fwd_1024_kernel
<
T
,
U
,
U
,
VecSize
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG
><<<
grid
,
THREADS_PER_CTA
,
0
,
stream
>>>
(
batch_size
,
feature_size
,
epsilon
,
x_data
,
static_cast
<
const
U
*>
(
void_scale_data
),
static_cast
<
const
U
*>
(
void_bias_data
),
mean_data
,
var_data
,
y_data
);
}
}
else
{
PADDLE_LAUNCH_LAYERNORM_FWD
(
U
,
false
);
#endif
if
(
is_scale_bias_same_dtype_with_x
)
{
PADDLE_LAUNCH_LAYERNORM_FWD
(
T
,
true
);
}
else
{
PADDLE_LAUNCH_LAYERNORM_FWD
(
U
,
false
);
}
#ifdef PADDLE_WITH_CUDA
}
#endif
#undef PADDLE_LAUNCH_LAYERNORM_FWD
}
};
...
...
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
浏览文件 @
01d04be6
...
...
@@ -278,6 +278,8 @@ class TestLayerNormOp(unittest.TestCase):
has_scale
=
False
,
has_bias
=
False
,
y_grad_scale
=
0.1
)
self
.
check_forward_backward
(
shape
=
[
512
,
1024
],
begin_norm_axis
=
1
,
has_scale
=
True
,
has_bias
=
True
)
class
TestLayerNormAPI
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录