BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit 85baa3c0 (unverified)
Authored Jun 02, 2022 by Li Min; committed via GitHub on Jun 02, 2022
Extend forward fast layer_norm kernel to support more dimensions. (#43118)
* extend forward fast_ln_kernel to support more column values.
Parent: 8c7cb3d6
Showing 3 changed files with 135 additions and 79 deletions (+135, -79)
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h (+6, -2)
paddle/fluid/operators/layer_norm_kernel.cu.h (+62, -27)
paddle/phi/kernels/gpu/layer_norm_kernel.cu (+67, -50)
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
@@ -478,11 +478,15 @@ void LaunchLayernormResidualDropoutBias(
 #define LAUNCH_FUSED_FAST_LN_KERNEL       \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768);  \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1280); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1536); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1792); \
+  LAUNCH_FUSED_FAST_LN_KERNEL_BASE(2048); \
   LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)

   bool can_call_fast_ln_kernel = false;
-  if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr &&
-      layernorm_bias != nullptr) {
+  if (((cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096) &&
+      scale != nullptr && layernorm_bias != nullptr) {
     can_call_fast_ln_kernel = true;
   }
   VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;
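Note: the rewritten condition widens the fast path from three hard-coded widths (768, 1024, 4096) to every multiple of 256 in [768, 2048], plus 4096. A minimal standalone sketch of the same predicate (the helper name is ours, not part of the patch):

    #include <cstdint>

    // Hypothetical helper mirroring the dispatch check above; accepts
    // 768, 1024, 1280, 1536, 1792, 2048, and 4096 columns.
    static bool CanUseFastLnKernel(int64_t cols, bool has_scale, bool has_bias) {
      const bool supported_cols =
          (cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096;
      return supported_cols && has_scale && has_bias;
    }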
paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -36,8 +36,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;

-#define LN_NUM_COLS 1024
-
 inline static int GetDesiredBlockDim(int64_t block_dim) {
 #ifdef __HIPCC__
   const int kMaxBlockDim = 256;
@@ -183,11 +181,12 @@ template <typename T, typename U, typename ScaleT = U, int VecSize = 8,
           int ROWS_PER_CTA = WARPS_M,
           int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize,
           int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
-__global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
+__global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
     int rows, int cols, const float epsilon, const T *__restrict__ x_ptr,
     const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
     U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
     T *__restrict__ y_ptr) {
+  __shared__ U smem[WARPS_M * WARPS_N];
   using Vec = phi::AlignedVector<T, VecSize>;
   using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
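Note: replacing the hard-coded LN_NUM_COLS with the ELTS_PER_ROW template parameter lets LDGS (vector loads per thread per row) fall out of the other constants. A worked instance of that arithmetic, assuming fp16 input and a 2048-wide row:

    // Assumed values: sizeof(T) == 2 (fp16), 2048 columns, 16-byte loads.
    constexpr int VecSize = 16 / 2;                                  // 8 elements per load
    constexpr int WARPS_N = 2048 / 1024;                             // 2 warps across a row
    constexpr int THREADS_PER_ROW = WARPS_N * 32;                    // 64
    constexpr int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW * VecSize;  // 512
    constexpr int LDGS = 2048 / ELTS_PER_ROW_PER_CTA;                // 4 loads per thread
    static_assert(LDGS * ELTS_PER_ROW_PER_CTA == 2048, "row must divide evenly");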
@@ -210,12 +209,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     col += THREADS_PER_ROW;
   }
-  constexpr U rn = 1.f / U(LN_NUM_COLS);
+  constexpr U rn = 1.f / U(ELTS_PER_ROW);
   for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
     Vec x[LDGS];
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
+      phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
       col += THREADS_PER_ROW;
     }
     U xf[LDGS * VecSize];
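Note: phi::Load reads VecSize contiguous elements through a 16-byte phi::AlignedVector in a single transaction. A self-contained CUDA sketch of the same striding pattern, with the built-in float4 standing in for AlignedVector (illustrative only, not the kernel's code):

    // Each thread walks one row in 16-byte chunks, mirroring the indexing
    // x_ptr + row * ELTS_PER_ROW + col * VecSize used above.
    __global__ void row_copy_demo(const float4 *__restrict__ x,
                                  float4 *__restrict__ y,
                                  int vec_cols /* cols / 4 */) {
      const int row = blockIdx.x;
      for (int col = threadIdx.x; col < vec_cols; col += blockDim.x) {
        y[row * vec_cols + col] = x[row * vec_cols + col];  // one 16 B load/store
      }
    }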
@@ -235,6 +234,23 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
     }
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = mu_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        mu_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          mu_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = mu_local;
+      }
+      __syncthreads();
+      mu_local = smem[warp_m];
+    }
+
     mu_local *= rn;
     if (lane == 0) {
       mean_out_ptr[row] = mu_local;
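Note: the shuffle loop above is a butterfly reduction: after log2(32) rounds of __shfl_xor_sync every lane holds the warp-wide sum, and the new WARPS_N > 1 branch then folds the per-warp partials through shared memory. The warp-level step in isolation, as a hedged sketch:

    // Butterfly all-reduce: every lane of the warp ends with the full sum.
    __device__ float WarpAllReduceSum(float v) {
      for (int offset = 1; offset < 32; offset *= 2) {
        v += __shfl_xor_sync(0xffffffffu, v, offset);
      }
      return v;
    }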
@@ -254,6 +270,24 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
       var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
     }
+    if (WARPS_N > 1) {
+      if (lane == 0) {
+        smem[warp_m * WARPS_N + warp_n] = var_local;
+      }
+      __syncthreads();
+      if (tidx == 0) {
+        var_local = 0.f;
+#pragma unroll
+        for (int it = 0; it < WARPS_N; ++it) {
+          var_local += smem[warp_m * WARPS_N + it];
+        }
+        smem[warp_m] = var_local;
+      }
+      __syncthreads();
+      var_local = smem[warp_m];
+    }
+
     // Note: to assure if it is right for double
     U rsigma = rsqrtf(var_local * rn + epsilon);
     if (lane == 0) {
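Note: with the row statistics in hand, each element is normalized as y = gamma * (x - mu) * rsigma + beta, where rsigma = rsqrtf(var * rn + epsilon) and rn = 1 / ELTS_PER_ROW. A scalar sketch of the epilogue that the vectorized loop below performs (the function name is illustrative):

    // Per-element layer_norm epilogue; mu and rsigma are the row statistics.
    __device__ float LayerNormElement(float x, float mu, float rsigma,
                                      float gamma, float beta) {
      return gamma * (x - mu) * rsigma + beta;
    }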
@@ -277,7 +311,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_fwd_1024_kernel(
 #pragma unroll
     for (int it = 0, col = c; it < LDGS; it++) {
-      phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
+      phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
       col += THREADS_PER_ROW;
     }
   }
@@ -416,10 +450,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   const int r = bidx * ROWS_PER_CTA + warp_m;
   const int c = warp_n * THREADS_PER_WARP + lane;
-  static_assert(LN_NUM_COLS == THREADS_PER_ROW * LDGS * VecSize, "");
+  static_assert(ELTS_PER_ROW == THREADS_PER_ROW * LDGS * VecSize, "");

   // smem for column reduction
-  __shared__ U smem_[ROWS_PER_CTA * LN_NUM_COLS];
+  __shared__ U smem_[ROWS_PER_CTA * ELTS_PER_ROW];
   U dgamma_sum[LDGS * VecSize];
   U dbeta_sum[LDGS * VecSize];
@@ -434,7 +468,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   U *sum_loss2_shared = &smem_sum_loss2[warp_m * WARPS_N];

   // step-1: compute dx and local results of dscale and dbias
-  constexpr float rn = 1.f / static_cast<float>(LN_NUM_COLS);
+  constexpr float rn = 1.f / static_cast<float>(ELTS_PER_ROW);
   Vec_scale gamma[LDGS];
   int col = c;
 #pragma unroll
@@ -452,12 +486,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
     int col = c;
 #pragma unroll
     for (int it = 0; it < LDGS; it++) {
-      phi::Load<T, VecSize>(dout_ptr + row * LN_NUM_COLS + col * VecSize,
+      phi::Load<T, VecSize>(dout_ptr + row * ELTS_PER_ROW + col * VecSize,
                             &dout[it]);
-      phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
+      phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
       if (isFusedDropoutResidualLn) {
         phi::Load<MaskType, VecSize>(
-            mask_ptr + row * LN_NUM_COLS + col * VecSize, &mask_vec[it]);
+            mask_ptr + row * ELTS_PER_ROW + col * VecSize, &mask_vec[it]);
       }
       col += THREADS_PER_ROW;
@@ -551,10 +585,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
     col = c;
 #pragma unroll
     for (int it = 0; it < LDGS; it++) {
-      phi::Store<T, VecSize>(x[it], dx_ptr + row * LN_NUM_COLS + col * VecSize);
+      phi::Store<T, VecSize>(x[it],
+                             dx_ptr + row * ELTS_PER_ROW + col * VecSize);
       if (isFusedDropoutResidualLn) {
         phi::Store<T, VecSize>(
-            dout[it], d_dropout_src_ptr + row * LN_NUM_COLS + col * VecSize);
+            dout[it], d_dropout_src_ptr + row * ELTS_PER_ROW + col * VecSize);
       }
       col += THREADS_PER_ROW;
     }
@@ -562,12 +597,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   // step-2: column reduction of dscale and dbias for each thread block.
   // each block's sum: [4 * 1024] -> [1 * 1024]
-  enum { NUM_RES = LN_NUM_COLS / THREADS_PER_CTA };  // 1024/128 = 8
-  static_assert(NUM_RES * THREADS_PER_CTA == LN_NUM_COLS, "");
+  enum { NUM_RES = ELTS_PER_ROW / THREADS_PER_CTA };  // 1024/128 = 8
+  static_assert(NUM_RES * THREADS_PER_CTA == ELTS_PER_ROW, "");
   U *smem_write;
-  smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];  // [4 * 1024]
+  smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];  // [4 * 1024]
 #pragma unroll
   for (int it = 0; it < LDGS; it++) {
 #pragma unroll
@@ -583,12 +618,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   for (int it = 0; it < ROWS_PER_CTA; it++) {
     for (int jt = 0; jt < NUM_RES; jt++) {
       cta_dbeta_sum[jt] +=
-          smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
+          smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
     }
   }
   __syncthreads();
-  smem_write = &smem_[warp_m * LN_NUM_COLS + tid_r * VecSize];
+  smem_write = &smem_[warp_m * ELTS_PER_ROW + tid_r * VecSize];
 #pragma unroll
   for (int it = 0; it < LDGS; it++) {
 #pragma unroll
@@ -603,19 +638,19 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel(
   for (int it = 0; it < ROWS_PER_CTA; it++) {
     for (int jt = 0; jt < NUM_RES; jt++) {
       cta_dgamma_sum[jt] +=
-          smem_[it * LN_NUM_COLS + tidx + jt * THREADS_PER_CTA];
+          smem_[it * ELTS_PER_ROW + tidx + jt * THREADS_PER_CTA];
     }
   }

   // the shape of results:(#blocks, 1024)
   U *dgamma_part =
-      static_cast<U *>(dgamma_temp_ptr) + bidx * LN_NUM_COLS + tidx;
+      static_cast<U *>(dgamma_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
   for (int jt = 0; jt < NUM_RES; jt++) {
     *dgamma_part = cta_dgamma_sum[jt];
     dgamma_part += THREADS_PER_CTA;
   }

-  U *dbeta_part = static_cast<U *>(dbeta_temp_ptr) + bidx * LN_NUM_COLS + tidx;
+  U *dbeta_part =
+      static_cast<U *>(dbeta_temp_ptr) + bidx * ELTS_PER_ROW + tidx;
   for (int jt = 0; jt < NUM_RES; jt++) {
     *dbeta_part = cta_dbeta_sum[jt];
     dbeta_part += THREADS_PER_CTA;
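Note: dgamma/dbeta are reduced in two stages: each CTA sums its ROWS_PER_CTA rows into a (#blocks, ELTS_PER_ROW) scratch buffer here, and ln_bwd_1024_final_kernel then folds the block dimension. A simplified sketch of that second stage's shape of work (a plain grid-stride column sum, not the actual kernel):

    // Illustrative stage 2: collapse num_blocks partial rows into one row.
    __global__ void column_sum_demo(const float *__restrict__ part,  // [num_blocks, cols]
                                    float *__restrict__ out,         // [cols]
                                    int num_blocks, int cols) {
      for (int c = blockIdx.x * blockDim.x + threadIdx.x; c < cols;
           c += gridDim.x * blockDim.x) {
        float acc = 0.f;
        for (int b = 0; b < num_blocks; ++b) {
          acc += part[b * cols + c];
        }
        out[c] = acc;
      }
    }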
@@ -640,7 +675,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
     const int rows, U *__restrict__ dg_part_, U *__restrict__ db_part_,
     ScaleT *__restrict__ dg_, ScaleT *__restrict__ db_) {
   using Vec = phi::AlignedVector<U, VecSize>;
-  static_assert(VEC_COLS == LN_NUM_COLS / VecSize, "");
+  static_assert(VEC_COLS == ELTS_PER_ROW / VecSize, "");

   const int tidx = threadIdx.x;
   const int bidx = blockIdx.x;
@@ -656,8 +691,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
   __shared__ U smem_space[(WARPS_M - 1) * THREADS_PER_ROW * VecSize];

   for (int col = c; col < VEC_COLS; col += gridDim.x * THREADS_PER_ROW) {
-    const U *dg_part_ptr = (dg_part_) + r * LN_NUM_COLS + col * VecSize;
-    const U *db_part_ptr = (db_part_) + r * LN_NUM_COLS + col * VecSize;
+    const U *dg_part_ptr = (dg_part_) + r * ELTS_PER_ROW + col * VecSize;
+    const U *db_part_ptr = (db_part_) + r * ELTS_PER_ROW + col * VecSize;

     U dg_sum[VecSize];
     U db_sum[VecSize];
@@ -669,8 +704,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
       Vec dg;
       Vec db;
       phi::Load<U, VecSize>(dg_part_ptr, &dg);
       phi::Load<U, VecSize>(db_part_ptr, &db);
-      dg_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
-      db_part_ptr += ROWS_PER_CTA * LN_NUM_COLS;
+      dg_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;
+      db_part_ptr += ROWS_PER_CTA * ELTS_PER_ROW;

 #pragma unroll
       for (int jt = 0; jt < VecSize; jt++) {
paddle/phi/kernels/gpu/layer_norm_kernel.cu
@@ -84,7 +84,7 @@ void LayerNormKernel(const Context &dev_ctx,
     PADDLE_ENFORCE_EQ(
         scale->dtype(),
         bias->dtype(),
-        phi::errors::InvalidArgument("Thie Scale and Bias of layer_norm op "
+        phi::errors::InvalidArgument("This Scale and Bias of layer_norm op "
                                      "should have the same data type."));
   }
 } else {
@@ -131,59 +131,75 @@ void LayerNormKernel(const Context &dev_ctx,
   } \
 } while (0)

+#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, feature_size)          \
+  case (feature_size): {                                                     \
+    constexpr int WARPS_N = feature_size < 1024 ? 1 : (feature_size / 1024); \
+    constexpr int WARPS_M = 4 / WARPS_N;                                     \
+    const int THREADS_PER_WARP = 32;                                         \
+    const int BYTES_PER_LDG = 16;                                            \
+    const int VecSize = BYTES_PER_LDG / sizeof(T);                           \
+    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;        \
+    const int ROWS_PER_CTA = WARPS_M;                                        \
+    const int grid = static_cast<int>(                                       \
+        std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));           \
+    paddle::operators::fast_ln_fwd_kernel<                                   \
+        T,                                                                   \
+        U,                                                                   \
+        ScaleT,                                                              \
+        VecSize,                                                             \
+        WARPS_M,                                                             \
+        WARPS_N,                                                             \
+        BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(                \
+        batch_size,                                                          \
+        feature_size,                                                        \
+        epsilon,                                                             \
+        x_data,                                                              \
+        static_cast<const ScaleT *>(void_scale_data),                        \
+        static_cast<const ScaleT *>(void_bias_data),                         \
+        mean_data,                                                           \
+        var_data,                                                            \
+        y_data);                                                             \
+  } break

+#define PADDLE_LAUNCH_FAST_LAYERNORM_FWD(ScaleT)       \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 768);  \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1024); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1280); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1536); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 1792); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 2048); \
+  PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE(ScaleT, 4096)
+
 #ifdef PADDLE_WITH_CUDA
-  bool can_call_1024_kernel = false;
-  if (feature_size == 1024 && scale != nullptr && bias != nullptr) {
-    can_call_1024_kernel = true;
+  bool can_call_fast_kernel = false;
+  if ((feature_size >= 768 && feature_size <= 2048 &&
+       feature_size % 256 == 0 ||
+       feature_size == 4096) &&
+      scale != nullptr && bias != nullptr) {
+    // can_call_fast_kernel = true;
+    can_call_fast_kernel = false;
   }
-  if (can_call_1024_kernel) {
-    const int WARPS_M = 4;
-    const int WARPS_N = 1;
-    const int THREADS_PER_WARP = 32;
-    const int BYTES_PER_LDG = 16;
-    const int VecSize = BYTES_PER_LDG / sizeof(T);
-    const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
-    const int ROWS_PER_CTA = WARPS_M;
-    const int grid = static_cast<int>(
-        std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));
+
+  if (can_call_fast_kernel) {
     if (is_scale_bias_same_dtype_with_x) {
-      paddle::operators::ln_fwd_1024_kernel<
-          T, U, T, VecSize, WARPS_M, WARPS_N,
-          BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
-          batch_size, feature_size, epsilon, x_data,
-          static_cast<const T *>(void_scale_data),
-          static_cast<const T *>(void_bias_data),
-          mean_data, var_data, y_data);
+      switch (feature_size) {
+        PADDLE_LAUNCH_FAST_LAYERNORM_FWD(T);
+        default:
+          PADDLE_THROW(phi::errors::InvalidArgument(
+              "Only when feature_size is from 256 to 4096 and is divisible "
+              "by 256 is supported now"));
+          break;
+      }
     } else {
-      paddle::operators::ln_fwd_1024_kernel<
-          T, U, U, VecSize, WARPS_M, WARPS_N,
-          BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
-          batch_size, feature_size, epsilon, x_data,
-          static_cast<const U *>(void_scale_data),
-          static_cast<const U *>(void_bias_data),
-          mean_data, var_data, y_data);
+      switch (feature_size) {
+        PADDLE_LAUNCH_FAST_LAYERNORM_FWD(U);
+        default:
+          PADDLE_THROW(phi::errors::InvalidArgument(
+              "Only when feature_size is from 256 to 4096 and is divisible "
+              "by 256 is supported now"));
+          break;
+      }
     }
   } else {
 #endif
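Note: the WARPS_N/WARPS_M arithmetic in PADDLE_LAUNCH_FAST_LAYERNORM_FWD_BASE keeps every configuration at 4 warps (128 threads) per CTA: rows narrower than 2048 columns get one warp per row and 4 rows per CTA, while 2048 and 4096 columns split each row across 2 or 4 warps. Restated outside the macro (our tabulation, not part of the patch):

    constexpr int WarpsN(int feature_size) {
      return feature_size < 1024 ? 1 : feature_size / 1024;  // integer division
    }
    constexpr int WarpsM(int feature_size) { return 4 / WarpsN(feature_size); }
    // 768..1792 -> WARPS_N = 1, WARPS_M = 4 (4 rows per 128-thread CTA)
    // 2048      -> WARPS_N = 2, WARPS_M = 2 (2 rows per CTA)
    // 4096      -> WARPS_N = 4, WARPS_M = 1 (one row spans the whole CTA)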
@@ -197,6 +213,7 @@ void LayerNormKernel(const Context &dev_ctx,
 #endif

 #undef PADDLE_LAUNCH_LAYERNORM_FWD
+#undef PADDLE_LAUNCH_FAST_LAYERNORM_FWD
 }

 }  // namespace phi