Unverified commit 7879477f — authored Apr 23, 2021 by ronnywang, committed via GitHub on Apr 23, 2021

[ROCM] add cuda kenrel for batch_norm_op (#32393)

Parent: 49773f36

Showing 2 changed files with 409 additions and 140 deletions (+409 −140):
  paddle/fluid/operators/batch_norm_op.cu   +386 −120
  paddle/fluid/operators/norm_utils.cu.h    +23 −20
paddle/fluid/operators/batch_norm_op.cu
@@ -41,6 +41,83 @@ using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
 using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 
+template <typename T, framework::DataLayout layout>
+static __global__ void BNForwardInference(
+    const T *x, const BatchNormParamType<T> *mean,
+    const BatchNormParamType<T> *variance, const BatchNormParamType<T> *scale,
+    const BatchNormParamType<T> *bias, const int C, const int N,
+    const int HxW, const double epsilon, T *y) {
+  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int num = N * C * HxW;
+  for (int i = gid; i < num; i += stride) {
+    const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
+    BatchNormParamType<T> x_sub_mean =
+        static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
+    BatchNormParamType<T> inv_var = 1 / sqrt(variance[c] + epsilon);
+    y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_var + bias[c]);
+  }
+}
+
+template <typename T, int BlockDim, framework::DataLayout layout>
+static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
+    const T *x, const BatchNormParamType<T> *scale,
+    const BatchNormParamType<T> *bias, const int C, const int N,
+    const int HxW, const double epsilon, double exponentialAverageFactor,
+    T *y, BatchNormParamType<T> *mean, BatchNormParamType<T> *variance,
+    BatchNormParamType<T> *save_mean,
+    BatchNormParamType<T> *save_inv_variance) {
+  int outer_size = C;
+  int inner_size = N * HxW;
+  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage mean_storage;
+  __shared__ typename BlockReduce::TempStorage variance_storeage;
+  __shared__ BatchNormParamType<T> mean_val;
+  __shared__ BatchNormParamType<T> variance_val;
+  __shared__ BatchNormParamType<T> inv_var_val;
+
+  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
+    BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
+    BatchNormParamType<T> x_square_sum = static_cast<BatchNormParamType<T>>(0);
+
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = layout == framework::DataLayout::kNCHW
+                            ? (j / HxW * C + i) * HxW + j % HxW
+                            : j * outer_size + i;
+      BatchNormParamType<T> x_i = static_cast<BatchNormParamType<T>>(x[index]);
+      x_sum += x_i;
+      x_square_sum += x_i * x_i;
+    }
+    x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
+    x_square_sum =
+        BlockReduce(variance_storeage).Reduce(x_square_sum, cub::Sum());
+    if (threadIdx.x == 0) {
+      mean_val = x_sum / inner_size;
+      variance_val = x_square_sum / inner_size - mean_val * mean_val;
+      inv_var_val = 1 / sqrt(variance_val + epsilon);
+
+      if (save_mean && save_inv_variance) {
+        save_mean[i] = mean_val;
+        save_inv_variance[i] = inv_var_val;
+      }
+      mean[i] = (1 - exponentialAverageFactor) * mean_val +
+                exponentialAverageFactor * mean[i];
+      variance[i] = (1 - exponentialAverageFactor) * variance_val +
+                    exponentialAverageFactor * variance[i];
+    }
+    __syncthreads();
+
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = layout == framework::DataLayout::kNCHW
+                            ? (j / HxW * C + i) * HxW + j % HxW
+                            : j * outer_size + i;
+      BatchNormParamType<T> x_sub_mean =
+          static_cast<BatchNormParamType<T>>(x[index]) - mean_val;
+      y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i];
+    }
+  }
+}
+
 template <typename T>
 class BatchNormKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
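For reference, the two kernels added above compute the standard batch-normalization forward pass directly instead of calling MIOpen. A sketch of the math they implement, in LaTeX, using the kernels' own parameter names (\gamma = scale, \beta = bias, \alpha = exponentialAverageFactor, m = N \cdot HxW elements per channel c):

y_i = \gamma_c \,\frac{x_i - \mu_c}{\sqrt{\sigma_c^2 + \varepsilon}} + \beta_c
\qquad \text{(BNForwardInference, using the running statistics } \mu_c, \sigma_c^2\text{)}

\mu_c = \frac{1}{m}\sum_{i \in c} x_i, \qquad
\sigma_c^2 = \frac{1}{m}\sum_{i \in c} x_i^2 - \mu_c^2
\qquad \text{(BNForwardTraining, batch statistics)}

\text{mean}[c] \leftarrow (1-\alpha)\,\mu_c + \alpha\,\text{mean}[c], \qquad
\text{variance}[c] \leftarrow (1-\alpha)\,\sigma_c^2 + \alpha\,\text{variance}[c]

BNForwardTraining additionally stores save_mean = \mu_c and save_inv_variance = 1/\sqrt{\sigma_c^2 + \varepsilon} for reuse in the backward pass.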
@@ -80,8 +157,12 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     auto dtype = platform::CudnnDataType<T>::type;
 
 #ifdef PADDLE_WITH_HIP
-    // HIP do not support compute format of NHWC
-    auto compute_format = DataLayout::kNCHW;
+    auto compute_format =
+        data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;
+
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// HIP do not support compute format of NHWC
+// auto compute_format = DataLayout::kNCHW;
 #else
     const bool fast_nhwc_batch_norm =
         test_mode ||
@@ -111,14 +192,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
 
     // ------------------- cudnn descriptors ---------------------
 #ifdef PADDLE_WITH_HIP
-    miopenTensorDescriptor_t data_desc_;
-    miopenTensorDescriptor_t bn_param_desc_;
-    miopenBatchNormMode_t mode_;
-
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// miopenTensorDescriptor_t data_desc_;
+// miopenTensorDescriptor_t bn_param_desc_;
+// miopenBatchNormMode_t mode_;
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
 #else
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
@@ -138,7 +220,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
 
 #ifdef PADDLE_WITH_HIP
-    mode_ = miopenBNSpatial;
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// mode_ = miopenBNSpatial;
 #elif CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -161,14 +244,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     }
 
 #ifdef PADDLE_WITH_HIP
-    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
-        data_desc_, CudnnDataType<T>::type,
-        x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
-        const_cast<int *>(strides.data())));
-    // Note: PERSISTENT not implemented for inference
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenDeriveBNTensorDescriptor(
-            bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
+//     data_desc_, CudnnDataType<T>::type,
+//     x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
+//     const_cast<int *>(strides.data())));
+// Note: PERSISTENT not implemented for inference
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenDeriveBNTensorDescriptor(
+//         bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
 #else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
@@ -226,28 +310,53 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
                                 C, est_var->dims()[0], est_var->dims()));
 
 #ifdef PADDLE_WITH_HIP
-      PADDLE_ENFORCE_CUDA_SUCCESS(
-          platform::dynload::miopenBatchNormalizationForwardInference(
-              handle, miopenBNSpatial,
-              const_cast<void *>(
-                  static_cast<const void *>(CudnnDataType<T>::kOne())),
-              const_cast<void *>(
-                  static_cast<const void *>(CudnnDataType<T>::kZero())),
-              data_desc_,
-              static_cast<const void *>(transformed_x.template data<T>()),
-              data_desc_,
-              static_cast<void *>(
-                  transformed_y.template mutable_data<T>(ctx.GetPlace())),
-              bn_param_desc_,
-              const_cast<void *>(static_cast<const void *>(
-                  scale->template data<BatchNormParamType<T>>())),
-              const_cast<void *>(static_cast<const void *>(
-                  bias->template data<BatchNormParamType<T>>())),
-              const_cast<void *>(static_cast<const void *>(
-                  est_mean->template data<BatchNormParamType<T>>())),
-              const_cast<void *>(static_cast<const void *>(
-                  est_var->template data<BatchNormParamType<T>>())),
-              epsilon));
+      const int block_size = 256;
+      const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
+      if (compute_format == DataLayout::kNCHW) {
+        BNForwardInference<
+            T,
+            DataLayout::kNCHW><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
+            transformed_x.template data<T>(),
+            est_mean->template data<BatchNormParamType<T>>(),
+            est_var->template data<BatchNormParamType<T>>(),
+            scale->template data<BatchNormParamType<T>>(),
+            bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
+            epsilon, transformed_y.template data<T>());
+      } else {
+        BNForwardInference<
+            T,
+            DataLayout::kNHWC><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
+            transformed_x.template data<T>(),
+            est_mean->template data<BatchNormParamType<T>>(),
+            est_var->template data<BatchNormParamType<T>>(),
+            scale->template data<BatchNormParamType<T>>(),
+            bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
+            epsilon, transformed_y.template data<T>());
+      }
+
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenBatchNormalizationForwardInference(
+//         handle, miopenBNSpatial,
+//         const_cast<void *>(
+//             static_cast<const void *>(CudnnDataType<T>::kOne())),
+//         const_cast<void *>(
+//             static_cast<const void *>(CudnnDataType<T>::kZero())),
+//         data_desc_,
+//         static_cast<const void *>(transformed_x.template data<T>()),
+//         data_desc_,
+//         static_cast<void *>(
+//             transformed_y.template mutable_data<T>(ctx.GetPlace())),
+//         bn_param_desc_,
+//         const_cast<void *>(static_cast<const void *>(
+//             scale->template data<BatchNormParamType<T>>())),
+//         const_cast<void *>(static_cast<const void *>(
+//             bias->template data<BatchNormParamType<T>>())),
+//         const_cast<void *>(static_cast<const void *>(
+//             est_mean->template data<BatchNormParamType<T>>())),
+//         const_cast<void *>(static_cast<const void *>(
+//             est_var->template data<BatchNormParamType<T>>())),
+//         epsilon));
 #else
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnBatchNormalizationForwardInference(
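The inference path above launches BNForwardInference with a plain element-wise, grid-stride configuration; restating the launch arithmetic from the code above (not part of the diff):

\text{block\_size} = 256, \qquad
\text{grid\_size} = \left\lceil \frac{N \cdot C \cdot H \cdot W \cdot D}{\text{block\_size}} \right\rceil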
@@ -365,34 +474,66 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
 #endif  // CUDNN_VERSION_MIN(7, 4, 1)
       if (!called) {
 #ifdef PADDLE_WITH_HIP
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            platform::dynload::miopenBatchNormalizationForwardTraining(
-                handle, mode_, const_cast<void *>(static_cast<const void *>(
-                                   CudnnDataType<T>::kOne())),
-                const_cast<void *>(
-                    static_cast<const void *>(CudnnDataType<T>::kZero())),
-                data_desc_,
-                static_cast<const void *>(transformed_x.template data<T>()),
-                data_desc_,
-                static_cast<void *>(
-                    transformed_y.template mutable_data<T>(ctx.GetPlace())),
-                bn_param_desc_,
-                const_cast<void *>(static_cast<const void *>(
-                    scale->template data<BatchNormParamType<T>>())),
-                const_cast<void *>(static_cast<const void *>(
-                    bias->template data<BatchNormParamType<T>>())),
-                this_factor,
-                static_cast<void *>(
-                    mean_out->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace())),
-                static_cast<void *>(variance_out->template mutable_data<
-                                    BatchNormParamType<T>>(ctx.GetPlace())),
-                epsilon,
-                static_cast<void *>(
-                    saved_mean->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace())),
-                static_cast<void *>(saved_variance->template mutable_data<
-                                    BatchNormParamType<T>>(ctx.GetPlace()))));
+        const int num = transformed_x.numel();
+        const int block = 256;
+        const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+        const int max_blocks = std::max(max_threads / block, 1);
+        const int grid = std::min(C, max_blocks);
+        if (compute_format == DataLayout::kNCHW) {
+          BNForwardTraining<
+              T, block,
+              DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
+              transformed_x.template data<T>(),
+              scale->template data<BatchNormParamType<T>>(),
+              bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
+              epsilon, this_factor, transformed_y.template data<T>(),
+              mean_out->template data<BatchNormParamType<T>>(),
+              variance_out->template data<BatchNormParamType<T>>(),
+              saved_mean->template data<BatchNormParamType<T>>(),
+              saved_variance->template data<BatchNormParamType<T>>());
+        } else {
+          BNForwardTraining<
+              T, block,
+              DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
+              transformed_x.template data<T>(),
+              scale->template data<BatchNormParamType<T>>(),
+              bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
+              epsilon, this_factor, transformed_y.template data<T>(),
+              mean_out->template data<BatchNormParamType<T>>(),
+              variance_out->template data<BatchNormParamType<T>>(),
+              saved_mean->template data<BatchNormParamType<T>>(),
+              saved_variance->template data<BatchNormParamType<T>>());
+        }
+
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenBatchNormalizationForwardTraining(
+//         handle, mode_, const_cast<void *>(static_cast<const void *>(
+//                            CudnnDataType<T>::kOne())),
+//         const_cast<void *>(
+//             static_cast<const void *>(CudnnDataType<T>::kZero())),
+//         data_desc_,
+//         static_cast<const void *>(transformed_x.template data<T>()),
+//         data_desc_,
+//         static_cast<void *>(
+//             transformed_y.template mutable_data<T>(ctx.GetPlace())),
+//         bn_param_desc_,
+//         const_cast<void *>(static_cast<const void *>(
+//             scale->template data<BatchNormParamType<T>>())),
+//         const_cast<void *>(static_cast<const void *>(
+//             bias->template data<BatchNormParamType<T>>())),
+//         this_factor,
+//         static_cast<void *>(
+//             mean_out->template mutable_data<BatchNormParamType<T>>(
+//                 ctx.GetPlace())),
+//         static_cast<void *>(variance_out->template mutable_data<
+//                             BatchNormParamType<T>>(ctx.GetPlace())),
+//         epsilon,
+//         static_cast<void *>(
+//             saved_mean->template mutable_data<BatchNormParamType<T>>(
+//                 ctx.GetPlace())),
+//         static_cast<void *>(saved_variance->template mutable_data<
+//                             BatchNormParamType<T>>(ctx.GetPlace()))));
 #else
         PADDLE_ENFORCE_CUDA_SUCCESS(
             platform::dynload::cudnnBatchNormalizationForwardTraining(
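The training path instead launches one thread block per channel (grid = min(C, max_blocks), block = 256) and lets each block reduce that channel's sums with cub::BlockReduce. A minimal, self-contained sketch of that reduction pattern, assuming CUDA/HIP with cub available (the kernel name and buffers here are hypothetical, not part of the commit):

#include <cub/cub.cuh>

// Each block sums n values; only thread 0 receives the reduced result, which is
// why the Paddle kernels then broadcast it to the block through __shared__ variables.
template <int BlockDim>
__global__ void BlockSumExample(const float *in, int n, float *block_sums) {
  typedef cub::BlockReduce<float, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  float partial = 0.0f;
  for (int j = threadIdx.x; j < n; j += blockDim.x) {
    partial += in[j];  // strided pass over this block's slice of the data
  }
  float total = BlockReduce(temp_storage).Reduce(partial, cub::Sum());
  if (threadIdx.x == 0) {
    block_sums[blockIdx.x] = total;  // valid only in thread 0
  }
}
// Example launch: BlockSumExample<256><<<grid, 256>>>(in, n, block_sums);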
@@ -423,11 +564,12 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           ctx, &transformed_y, y);
     }
 #ifdef PADDLE_WITH_HIP
-    // clean when exit.
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// clean when exit.
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
 #else
     // clean when exit.
     PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -439,7 +581,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
 };
 
 template <typename T, int BlockDim, framework::DataLayout layout>
-static __global__ void KeBNBackwardScaleBias(
+static __global__ LAUNCH_BOUNDS(BlockDim) void KeBNBackwardScaleBias(
     const T *dy, const T *x, const BatchNormParamType<T> *mean,
     const BatchNormParamType<T> *variance, const double epsilon, const int N,
     const int C, const int HxW, BatchNormParamType<T> *dscale,
@@ -526,13 +668,97 @@ class InplaceHelper {
 };
 
 template <typename T, int BlockDim, framework::DataLayout layout>
-static __global__ void BNBackwardData(const T *dy,
-                                      const BatchNormParamType<T> *scale,
-                                      const BatchNormParamType<T> *mean,
-                                      const T *x,
-                                      const BatchNormParamType<T> *variance,
-                                      const int C, const int N, const int HxW,
-                                      T *dx) {
+static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
+    const T *dy, const T *x, const BatchNormParamType<T> *scale,
+    const BatchNormParamType<T> *saved_mean,
+    const BatchNormParamType<T> *saved_inv_variance, const int C, const int N,
+    const int HxW, const double epsilon, T *dx,
+    BatchNormParamType<T> *dscale, BatchNormParamType<T> *dbias) {
+  const int outer_size = C;
+  const int inner_size = N * HxW;
+  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage ds_storage;
+  __shared__ typename BlockReduce::TempStorage db_storage;
+  __shared__ typename BlockReduce::TempStorage mean_storage;
+  __shared__ typename BlockReduce::TempStorage variance_storeage;
+  __shared__ BatchNormParamType<T> inv_var_val;
+  __shared__ BatchNormParamType<T> mean_val;
+  __shared__ BatchNormParamType<T> dscale_val;
+  __shared__ BatchNormParamType<T> dbias_val;
+
+  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
+    BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
+    BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
+
+    if (saved_mean && saved_inv_variance) {
+      if (threadIdx.x == 0) {
+        inv_var_val = saved_inv_variance[i];
+        mean_val = saved_mean[i];
+      }
+    } else {
+      BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
+      BatchNormParamType<T> x_square_sum =
+          static_cast<BatchNormParamType<T>>(0);
+
+      for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+        const int index = layout == framework::DataLayout::kNCHW
+                              ? (j / HxW * C + i) * HxW + j % HxW
+                              : j * outer_size + i;
+        BatchNormParamType<T> x_i =
+            static_cast<BatchNormParamType<T>>(x[index]);
+        x_sum += x_i;
+        x_square_sum += x_i * x_i;
+      }
+      x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
+      x_square_sum =
+          BlockReduce(variance_storeage).Reduce(x_square_sum, cub::Sum());
+      if (threadIdx.x == 0) {
+        mean_val = x_sum / inner_size;
+        inv_var_val =
+            1 / sqrt(x_square_sum / inner_size - mean_val * mean_val + epsilon);
+      }
+    }
+    __syncthreads();
+
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = layout == framework::DataLayout::kNCHW
+                            ? (j / HxW * C + i) * HxW + j % HxW
+                            : j * outer_size + i;
+      BatchNormParamType<T> dy_i =
+          static_cast<BatchNormParamType<T>>(dy[index]);
+      ds_sum +=
+          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
+      db_sum += dy_i;
+    }
+    ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
+    db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
+    if (threadIdx.x == 0) {
+      dscale_val = ds_sum * inv_var_val;
+      dbias_val = db_sum;
+      dscale[i] = dscale_val;
+      dbias[i] = dbias_val;
+    }
+    __syncthreads();
+
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = layout == framework::DataLayout::kNCHW
+                            ? (j / HxW * C + i) * HxW + j % HxW
+                            : j * outer_size + i;
+      dx[index] = scale[i] * inv_var_val *
+                  (static_cast<BatchNormParamType<T>>(dy[index]) -
+                   dbias_val / static_cast<BatchNormParamType<T>>(inner_size) -
+                   (static_cast<BatchNormParamType<T>>(x[index]) - mean_val) *
+                       inv_var_val * dscale_val / inner_size);
+    }
+  }
+}
+
+template <typename T, int BlockDim, framework::DataLayout layout>
+static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
+    const T *dy, const BatchNormParamType<T> *scale,
+    const BatchNormParamType<T> *mean, const T *x,
+    const BatchNormParamType<T> *variance, const int C, const int N,
+    const int HxW, T *dx) {
   const int outer_size = C;
   const int inner_size = N * HxW;
   typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
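The BNBackward kernel added above fuses the two channel-wise reductions (for d\gamma and d\beta) with the computation of dx. A sketch of the gradients it evaluates, using the same notation as the forward sketch and m = N \cdot HxW:

d\beta_c = \sum_{i \in c} dy_i, \qquad
d\gamma_c = \frac{1}{\sqrt{\sigma_c^2+\varepsilon}} \sum_{i \in c} dy_i\,(x_i - \mu_c)

dx_i = \frac{\gamma_c}{\sqrt{\sigma_c^2+\varepsilon}}
\left( dy_i - \frac{d\beta_c}{m} - \frac{(x_i - \mu_c)\, d\gamma_c}{m\,\sqrt{\sigma_c^2+\varepsilon}} \right)

When saved_mean/saved_inv_variance are not provided, the kernel first recomputes \mu_c and 1/\sqrt{\sigma_c^2+\varepsilon} from x with the same block reduction.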
@@ -567,7 +793,6 @@ static __global__ void BNBackwardData(const T *dy,
       dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
     }
     __syncthreads();
 
     for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
       const int index = layout == framework::DataLayout::kNCHW
                             ? (j / HxW * C + i) * HxW + j % HxW
@@ -668,8 +893,12 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     auto dtype = platform::CudnnDataType<T>::type;
     const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
 #ifdef PADDLE_WITH_HIP
-    // HIP do not support compute format of NHWC
-    auto compute_format = DataLayout::kNCHW;
+    auto compute_format =
+        data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;
+
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// HIP do not support compute format of NHWC
+// auto compute_format = DataLayout::kNCHW;
 #else
     const bool fast_nhwc_batch_norm =
         dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
@@ -714,7 +943,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     const int num = transformed_x.numel();
+#ifdef HIPCC
+    const int block = 256;
+#else
     const int block = 512;
+#endif
     int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
     const int max_blocks = std::max(max_threads / block, 1);
     int grid1 = (num + block - 1) / block;
@@ -734,14 +967,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
 
     // ------------------- cudnn descriptors ---------------------
 #ifdef PADDLE_WITH_HIP
-    miopenTensorDescriptor_t data_desc_;
-    miopenTensorDescriptor_t bn_param_desc_;
-    miopenBatchNormMode_t mode_;
-
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// miopenTensorDescriptor_t data_desc_;
+// miopenTensorDescriptor_t bn_param_desc_;
+// miopenBatchNormMode_t mode_;
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
 #else
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
@@ -759,7 +993,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
 #ifdef PADDLE_WITH_HIP
-    mode_ = miopenBNSpatial;
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// mode_ = miopenBNSpatial;
 #elif CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -771,13 +1006,14 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
 #endif  // CUDNN_VERSION_MIN(7, 0, 1)
 
 #ifdef PADDLE_WITH_HIP
-    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
-        data_desc_, CudnnDataType<T>::type,
-        x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
-        const_cast<int *>(strides.data())));
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
-                                                          data_desc_, mode_));
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
+//     data_desc_, CudnnDataType<T>::type,
+//     x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
+//     const_cast<int *>(strides.data())));
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
+//     data_desc_, mode_));
 #else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
@@ -871,20 +1107,49 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
 #endif  // CUDNN_VERSION_MIN(7, 4, 1)
       if (!called) {
 #ifdef PADDLE_WITH_HIP
-      PADDLE_ENFORCE_CUDA_SUCCESS(
-          platform::dynload::miopenBatchNormalizationBackward(
-              dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
-              CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
-              CudnnDataType<T>::kZero(), data_desc_,
-              transformed_x.template data<T>(), data_desc_,
-              transformed_d_y.template data<T>(), data_desc_,
-              transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
-              bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
-              d_scale->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
-              d_bias->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
-              epsilon, saved_mean_data, saved_var_data));
+        if (compute_format == DataLayout::kNCHW) {
+          BNBackward<T, block,
+                     DataLayout::kNCHW><<<grid2, block, 0, dev_ctx.stream()>>>(
+              transformed_d_y.template data<T>(),
+              transformed_x.template data<T>(),
+              scale->template data<BatchNormParamType<T>>(), saved_mean_data,
+              saved_var_data, C, N, H * W * D, epsilon,
+              transformed_d_x.template data<T>(),
+              d_scale->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()),
+              d_bias->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()));
+        } else {
+          BNBackward<T, block,
+                     DataLayout::kNHWC><<<grid2, block, 0, dev_ctx.stream()>>>(
+              transformed_d_y.template data<T>(),
+              transformed_x.template data<T>(),
+              scale->template data<BatchNormParamType<T>>(), saved_mean_data,
+              saved_var_data, C, N, H * W * D, epsilon,
+              transformed_d_x.template data<T>(),
+              d_scale->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()),
+              d_bias->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()));
+        }
+
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenBatchNormalizationBackward(
+//         dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
+//         CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
+//         CudnnDataType<T>::kZero(), data_desc_,
+//         transformed_x.template data<T>(), data_desc_,
+//         transformed_d_y.template data<T>(), data_desc_,
+//         transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
+//         bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+//         d_scale->template mutable_data<BatchNormParamType<T>>(
+//             ctx.GetPlace()),
+//         d_bias->template mutable_data<BatchNormParamType<T>>(
+//             ctx.GetPlace()),
+//         epsilon, saved_mean_data, saved_var_data));
 #else
         PADDLE_ENFORCE_CUDA_SUCCESS(
             platform::dynload::cudnnBatchNormalizationBackward(
@@ -931,11 +1196,12 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     }
 #ifdef PADDLE_WITH_HIP
-    // clean when exit.
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
+// TODO(wangran16): wait for MIOpen to improve the performance of BN
+// clean when exit.
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
+// PADDLE_ENFORCE_CUDA_SUCCESS(
+//     platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
 #else
     // clean when exit.
     PADDLE_ENFORCE_CUDA_SUCCESS(
paddle/fluid/operators/norm_utils.cu.h
@@ -32,6 +32,12 @@ namespace cub = hipcub;
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
 
+#ifdef __HIPCC__
+#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
+#else
+#define LAUNCH_BOUNDS(BlockDim)
+#endif
+
 namespace paddle {
 namespace operators {
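LAUNCH_BOUNDS only expands to something under __HIPCC__, where __launch_bounds__(BlockDim) tells the HIP compiler the maximum block size the kernel will be launched with so it can budget registers accordingly; on CUDA builds the macro expands to nothing. A usage sketch (ExampleKernel is hypothetical, not part of this diff):

// HIP build:  static __global__ __launch_bounds__(BlockDim) void ExampleKernel(...)
// CUDA build: static __global__ void ExampleKernel(...)
template <typename T, int BlockDim>
static __global__ LAUNCH_BOUNDS(BlockDim) void ExampleKernel(const T *in, T *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];  // trivial copy, just to show where the attribute goes
}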
@@ -58,12 +64,10 @@ using DataLayout = framework::DataLayout;
 //          axis=(n,h,w)))
 template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void DoubleGradComputeDX(const T *x, const T *mean,
-                                    const T *variance, const T *ddx,
-                                    const T *dy, const T *scale,
-                                    const T *ddscale, const int N, const int C,
-                                    const int sample_size, const double epsilon,
-                                    T *dx) {
+__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDX(
+    const T *x, const T *mean, const T *variance, const T *ddx, const T *dy,
+    const T *scale, const T *ddscale, const int N, const int C,
+    const int sample_size, const double epsilon, T *dx) {
   const int outer_size = C;
   const int inner_size = N * sample_size;
@@ -160,12 +164,10 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
 //          scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
 //          np.mean(ddx * (x - mean), axis=(n,h,w)))
 template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void DoubleGradComputeDDY(const T *x, const T *mean,
-                                     const T *variance, const T *ddscale,
-                                     const T *ddbias, const T *ddx,
-                                     const T *scale, const int N, const int C,
-                                     const int sample_size,
-                                     const double epsilon, T *ddy) {
+__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDDY(
+    const T *x, const T *mean, const T *variance, const T *ddscale,
+    const T *ddbias, const T *ddx, const T *scale, const int N, const int C,
+    const int sample_size, const double epsilon, T *ddy) {
   const int outer_size = C;
   const int inner_size = N * sample_size;
@@ -238,11 +240,10 @@ __global__ void DoubleGradComputeDDY(const T *x, const T *mean,
 //            inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) *
 //            ddx
 template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void DoubleGradComputeDScale(const T *x, const T *mean,
-                                        const T *variance, const T *ddx,
-                                        const T *dy, const int N, const int C,
-                                        const int sample_size,
-                                        const double epsilon, T *dscale) {
+__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScale(
+    const T *x, const T *mean, const T *variance, const T *ddx, const T *dy,
+    const int N, const int C, const int sample_size, const double epsilon,
+    T *dscale) {
   const int outer_size = C;
   const int inner_size = N * sample_size;
@@ -302,7 +303,7 @@ __global__ void DoubleGradComputeDScale(const T *x, const T *mean,
 // math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
 template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void DoubleGradComputeDScaleWithGlobal(
+__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScaleWithGlobal(
     const T *ddx, const T *variance, const T *dy, const double epsilon,
     const int N, const int C, const int sample_size, T *dscale) {
   int outer_size = C;
@@ -422,8 +423,11 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
       set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
     }
     const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
+#ifdef __HIPCC__
+    const int block = 256;
+#else
     const int block = 512;
+#endif
     int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
     const int max_blocks = std::max(max_threads / block, 1);
     int grid = std::min(C, max_blocks);
@@ -532,6 +536,5 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
     }
   }
 }
 }  // namespace operators
 }  // namespace paddle