Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6cdc18af
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6cdc18af
编写于
11月 11, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
11月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Sparse]Optimize BatchNorm1D forward in test mode (#47736)
上级
1ad95e97
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
72 additions
and
13 deletions
+72
-13
paddle/phi/kernels/gpu/batch_norm_kernel.cu
paddle/phi/kernels/gpu/batch_norm_kernel.cu
+72
-13
未找到文件。
paddle/phi/kernels/gpu/batch_norm_kernel.cu
浏览文件 @
6cdc18af
...
@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
...
@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
}
}
}
}
// Precomputes, per channel, the inverse standard deviation
// 1 / sqrt(variance + epsilon). Despite the name, the result is the
// reciprocal of the standard deviation, not of the variance; inference
// kernels then multiply by it instead of dividing per element.
//
// variance:     per-channel running variance, length C (read-only).
// epsilon:      numerical-stability term added under the sqrt.
// C:            number of channels; one thread handles one channel.
// inv_variance: output buffer, length C.
template <typename T>
static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
                                       const double epsilon,
                                       const int C,
                                       BatchNormParamType<T> *inv_variance) {
  const int channel = threadIdx.x + blockIdx.x * blockDim.x;
  // The launch configuration may round the thread count up past C.
  if (channel >= C) {
    return;
  }
  inv_variance[channel] = 1 / sqrt(variance[channel] + epsilon);
}
// Inference-mode batch normalization using a precomputed inverse
// standard deviation:
//   y[i] = scale[c] * (x[i] - mean[c]) * inv_variance[c] + bias[c]
//
// x:            input tensor, N * C * HxW elements.
// mean:         per-channel running mean, length C.
// inv_variance: per-channel 1 / sqrt(var + eps) (see InverseVariance).
// scale, bias:  per-channel affine parameters, length C.
// C, N, HxW:    channel count, batch size, spatial size.
// epsilon:      unused in this kernel — presumably already folded into
//               inv_variance by the caller; kept for signature parity
//               with BNForwardInference (TODO confirm at call sites).
// y:            output tensor, same element count as x.
template <typename T, phi::DataLayout layout>
static __global__ void BN1DForwardInference(
    const T *x,
    const BatchNormParamType<T> *mean,
    const BatchNormParamType<T> *inv_variance,
    const BatchNormParamType<T> *scale,
    const BatchNormParamType<T> *bias,
    const int C,
    const int N,
    const int HxW,
    const double epsilon,
    T *y) {
  const int total = N * C * HxW;
  const int step = blockDim.x * gridDim.x;
  // Each thread walks the flattened tensor with a stride of the whole grid.
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total;
       idx += step) {
    // Recover the channel index from the flat offset; the mapping
    // depends on the memory layout of x.
    int c;
    if (layout == phi::DataLayout::kNCHW) {
      c = idx / HxW % C;
    } else {
      c = idx % C;
    }
    // Normalize in the parameter type (wider than T for fp16 inputs),
    // then narrow back to T on store.
    const BatchNormParamType<T> centered =
        static_cast<BatchNormParamType<T>>(x[idx]) - mean[c];
    y[idx] = static_cast<T>(scale[c] * centered * inv_variance[c] + bias[c]);
  }
}
template
<
typename
T
,
int
BlockDim
,
phi
::
DataLayout
layout
>
template
<
typename
T
,
int
BlockDim
,
phi
::
DataLayout
layout
>
static
__global__
LAUNCH_BOUNDS
(
BlockDim
)
void
BNForwardTraining
(
static
__global__
LAUNCH_BOUNDS
(
BlockDim
)
void
BNForwardTraining
(
const
T
*
x
,
const
T
*
x
,
...
@@ -795,7 +829,7 @@ void BatchNormKernel(const Context &ctx,
...
@@ -795,7 +829,7 @@ void BatchNormKernel(const Context &ctx,
// epsilon));
// epsilon));
#else
#else
const
bool
use_native_kernel
=
const
bool
use_native_kernel
=
(
(
x_dims
.
size
()
==
2
&&
N
>=
CUDNN_PER_ACTIVATION_THRESHOLD
)
||
(
x_dims
.
size
()
==
2
||
(
x_dims
.
size
()
==
3
&&
N
>=
CUDNN_SPATIAL_THRESHOLD
));
(
x_dims
.
size
()
==
3
&&
N
>=
CUDNN_SPATIAL_THRESHOLD
));
if
(
use_native_kernel
)
{
if
(
use_native_kernel
)
{
const
int
block_size
=
256
;
const
int
block_size
=
256
;
...
@@ -814,18 +848,43 @@ void BatchNormKernel(const Context &ctx,
...
@@ -814,18 +848,43 @@ void BatchNormKernel(const Context &ctx,
epsilon
,
epsilon
,
transformed_y
.
template
data
<
T
>());
transformed_y
.
template
data
<
T
>());
}
else
{
}
else
{
BNForwardInference
<
T
,
DataLayout
::
kNHWC
>
if
(
x_dims
.
size
()
==
2
)
{
<<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
DenseTensor
inv_var
=
phi
::
Empty
<
BatchNormParamType
<
T
>>
(
ctx
,
{
C
});
transformed_x
.
template
data
<
T
>(),
auto
*
inv_var_ptr
=
inv_var
.
data
<
BatchNormParamType
<
T
>>
();
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
const
int
threads
=
512
>
C
?
C
:
512
;
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
const
int
blocks
=
(
C
+
511
)
/
512
;
scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
InverseVariance
<
T
><<<
blocks
,
threads
>>>
(
bias
.
template
data
<
BatchNormParamType
<
T
>
>
(),
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
C
,
epsilon
,
N
,
C
,
H
*
W
*
D
,
inv_var_ptr
);
epsilon
,
BN1DForwardInference
<
T
,
DataLayout
::
kNHWC
>
transformed_y
.
template
data
<
T
>());
<<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
transformed_x
.
template
data
<
T
>(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
// est_var->template data<BatchNormParamType<T>>(),
inv_var_ptr
,
scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
bias
.
template
data
<
BatchNormParamType
<
T
>
>
(),
C
,
N
,
H
*
W
*
D
,
epsilon
,
transformed_y
.
template
data
<
T
>());
}
else
{
BNForwardInference
<
T
,
DataLayout
::
kNHWC
>
<<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
transformed_x
.
template
data
<
T
>(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
bias
.
template
data
<
BatchNormParamType
<
T
>
>
(),
C
,
N
,
H
*
W
*
D
,
epsilon
,
transformed_y
.
template
data
<
T
>());
}
}
}
}
else
{
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
PADDLE_ENFORCE_GPU_SUCCESS
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录