Unverified commit 7879477f, authored by ronnywang and committed by GitHub

[ROCM] add CUDA kernel for batch_norm_op (#32393)

Parent 49773f36
@@ -41,6 +41,83 @@ using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
template <typename T, framework::DataLayout layout>
static __global__ void BNForwardInference(
const T *x, const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance, const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias, const int C, const int N, const int HxW,
const double epsilon, T *y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int num = N * C * HxW;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> x_sub_mean =
static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
BatchNormParamType<T> inv_var = 1 / sqrt(variance[c] + epsilon);
y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_var + bias[c]);
}
}
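// Math sketch of BNForwardInference above, in the NumPy-style notation used
// by the comments elsewhere in this file:
//   inv_var = 1 / np.sqrt(variance[c] + epsilon)
//   y[i] = scale[c] * (x[i] - mean[c]) * inv_var + bias[c]
// where c is the channel index recovered from i for the NCHW or NHWC layout.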
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
const T *x, const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias, const int C, const int N, const int HxW,
const double epsilon, double exponentialAverageFactor, T *y,
BatchNormParamType<T> *mean, BatchNormParamType<T> *variance,
BatchNormParamType<T> *save_mean,
BatchNormParamType<T> *save_inv_variance) {
int outer_size = C;
int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage mean_storage;
__shared__ typename BlockReduce::TempStorage variance_storeage;
__shared__ BatchNormParamType<T> mean_val;
__shared__ BatchNormParamType<T> variance_val;
__shared__ BatchNormParamType<T> inv_var_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> x_square_sum = static_cast<BatchNormParamType<T>>(0);
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
BatchNormParamType<T> x_i = static_cast<BatchNormParamType<T>>(x[index]);
x_sum += x_i;
x_square_sum += x_i * x_i;
}
x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
x_square_sum =
BlockReduce(variance_storeage).Reduce(x_square_sum, cub::Sum());
if (threadIdx.x == 0) {
mean_val = x_sum / inner_size;
variance_val = x_square_sum / inner_size - mean_val * mean_val;
inv_var_val = 1 / sqrt(variance_val + epsilon);
if (save_mean && save_inv_variance) {
save_mean[i] = mean_val;
save_inv_variance[i] = inv_var_val;
}
mean[i] = (1 - exponentialAverageFactor) * mean_val +
exponentialAverageFactor * mean[i];
variance[i] = (1 - exponentialAverageFactor) * variance_val +
exponentialAverageFactor * variance[i];
}
__syncthreads();
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
BatchNormParamType<T> x_sub_mean =
static_cast<BatchNormParamType<T>>(x[index]) - mean_val;
y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i];
}
}
}
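// Math sketch of BNForwardTraining above (NumPy-style, per channel c):
//   batch_mean = np.mean(x_c)
//   batch_var = np.mean(x_c * x_c) - batch_mean * batch_mean
//   save_inv_variance[c] = 1 / np.sqrt(batch_var + epsilon)
//   y = scale[c] * (x_c - batch_mean) * save_inv_variance[c] + bias[c]
// and the running statistics are updated as
//   mean[c] = (1 - exponentialAverageFactor) * batch_mean +
//             exponentialAverageFactor * mean[c]
// (and likewise for variance[c]).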
template <typename T>
class BatchNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
@@ -80,8 +157,12 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
auto dtype = platform::CudnnDataType<T>::type;
#ifdef PADDLE_WITH_HIP
auto compute_format = data_layout == DataLayout::kNHWC ? DataLayout::kNHWC
: DataLayout::kNCHW;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW;
#else
const bool fast_nhwc_batch_norm =
test_mode ||
@@ -111,14 +192,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// miopenTensorDescriptor_t data_desc_;
// miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_;

// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
@@ -138,7 +220,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
#elif CUDNN_VERSION_MIN(7, 0, 1)
if (FLAGS_cudnn_batchnorm_spatial_persistent) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -161,14 +244,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
}
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
// data_desc_, CudnnDataType<T>::type,
// x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
// const_cast<int *>(strides.data())));
// Note: PERSISTENT not implemented for inference
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDeriveBNTensorDescriptor(
// bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
@@ -226,28 +310,53 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
C, est_var->dims()[0], est_var->dims()));
#ifdef PADDLE_WITH_HIP
const int block_size = 256;
const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
if (compute_format == DataLayout::kNCHW) {
BNForwardInference<
T,
DataLayout::kNCHW><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
epsilon, transformed_y.template data<T>());
} else {
BNForwardInference<
T,
DataLayout::kNHWC><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
epsilon, transformed_y.template data<T>());
}
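// Grid sizing example (hypothetical shapes, not taken from the patch): with
// N = 32, C = 64, H = W = 56, D = 1 and block_size = 256, the launch covers
// 32 * 64 * 56 * 56 = 6,422,528 elements with grid_size = 25,088 blocks; the
// grid-stride loop inside BNForwardInference also handles sizes that are not
// an exact multiple of block_size.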
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenBatchNormalizationForwardInference(
// handle, miopenBNSpatial,
// const_cast<void *>(
// static_cast<const void *>(CudnnDataType<T>::kOne())),
// const_cast<void *>(
// static_cast<const void *>(CudnnDataType<T>::kZero())),
// data_desc_,
// static_cast<const void *>(transformed_x.template data<T>()),
// data_desc_,
// static_cast<void *>(
// transformed_y.template mutable_data<T>(ctx.GetPlace())),
// bn_param_desc_,
// const_cast<void *>(static_cast<const void *>(
// scale->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// bias->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// est_mean->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// est_var->template data<BatchNormParamType<T>>())),
// epsilon));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnBatchNormalizationForwardInference(
@@ -365,34 +474,66 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if (!called) {
#ifdef PADDLE_WITH_HIP
const int num = transformed_x.numel();
const int block = 256;
const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
const int grid = std::min(C, max_blocks);
if (compute_format == DataLayout::kNCHW) {
BNForwardTraining<
T, block,
DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
transformed_x.template data<T>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
epsilon, this_factor, transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
} else {
BNForwardTraining<
T, block,
DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
transformed_x.template data<T>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
epsilon, this_factor, transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
}
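// Launch-shape note for the training path above: BNForwardTraining is given
// one thread block per channel (grid = min(C, max_blocks)); each block
// reduces the N * H * W * D elements of its channel with cub::BlockReduce,
// and blocks stride over channels when C exceeds the resident block count.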
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenBatchNormalizationForwardTraining(
// handle, mode_, const_cast<void *>(static_cast<const void *>(
// CudnnDataType<T>::kOne())),
// const_cast<void *>(
// static_cast<const void *>(CudnnDataType<T>::kZero())),
// data_desc_,
// static_cast<const void *>(transformed_x.template data<T>()),
// data_desc_,
// static_cast<void *>(
// transformed_y.template mutable_data<T>(ctx.GetPlace())),
// bn_param_desc_,
// const_cast<void *>(static_cast<const void *>(
// scale->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// bias->template data<BatchNormParamType<T>>())),
// this_factor,
// static_cast<void *>(
// mean_out->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace())),
// static_cast<void *>(variance_out->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace())),
// epsilon,
// static_cast<void *>(
// saved_mean->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace())),
// static_cast<void *>(saved_variance->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace()))));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnBatchNormalizationForwardTraining(
@@ -423,11 +564,12 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
ctx, &transformed_y, y);
}
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean when exit.
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -439,7 +581,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
};
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void KeBNBackwardScaleBias(
const T *dy, const T *x, const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance, const double epsilon, const int N,
const int C, const int HxW, BatchNormParamType<T> *dscale,
@@ -526,13 +668,97 @@ class InplaceHelper {
};
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
const T *dy, const T *x, const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *saved_mean,
const BatchNormParamType<T> *saved_inv_variance, const int C, const int N,
const int HxW, const double epsilon, T *dx, BatchNormParamType<T> *dscale,
BatchNormParamType<T> *dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
__shared__ typename BlockReduce::TempStorage mean_storage;
__shared__ typename BlockReduce::TempStorage variance_storeage;
__shared__ BatchNormParamType<T> inv_var_val;
__shared__ BatchNormParamType<T> mean_val;
__shared__ BatchNormParamType<T> dscale_val;
__shared__ BatchNormParamType<T> dbias_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
if (saved_mean && saved_inv_variance) {
if (threadIdx.x == 0) {
inv_var_val = saved_inv_variance[i];
mean_val = saved_mean[i];
}
} else {
BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> x_square_sum =
static_cast<BatchNormParamType<T>>(0);
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
BatchNormParamType<T> x_i =
static_cast<BatchNormParamType<T>>(x[index]);
x_sum += x_i;
x_square_sum += x_i * x_i;
}
x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
x_square_sum =
BlockReduce(variance_storeage).Reduce(x_square_sum, cub::Sum());
if (threadIdx.x == 0) {
mean_val = x_sum / inner_size;
inv_var_val =
1 / sqrt(x_square_sum / inner_size - mean_val * mean_val + epsilon);
}
}
__syncthreads();
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
BatchNormParamType<T> dy_i =
static_cast<BatchNormParamType<T>>(dy[index]);
ds_sum +=
dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
db_sum += dy_i;
}
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
if (threadIdx.x == 0) {
dscale_val = ds_sum * inv_var_val;
dbias_val = db_sum;
dscale[i] = dscale_val;
dbias[i] = dbias_val;
}
__syncthreads();
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
dx[index] = scale[i] * inv_var_val *
(static_cast<BatchNormParamType<T>>(dy[index]) -
dbias_val / static_cast<BatchNormParamType<T>>(inner_size) -
(static_cast<BatchNormParamType<T>>(x[index]) - mean_val) *
inv_var_val * dscale_val / inner_size);
}
}
}
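// Math sketch of BNBackward above (NumPy-style, per channel c, with
// m = N * HxW and inv_var either saved_inv_variance[c] or the value
// recomputed from x):
//   dbias[c] = np.sum(dy_c)
//   dscale[c] = np.sum(dy_c * (x_c - mean_c)) * inv_var
//   dx = scale[c] * inv_var *
//        (dy_c - dbias[c] / m - (x_c - mean_c) * inv_var * dscale[c] / m)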
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
const T *dy, const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *mean, const T *x,
const BatchNormParamType<T> *variance, const int C, const int N,
const int HxW, T *dx) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
@@ -567,7 +793,6 @@ static __global__ void BNBackwardData(const T *dy,
dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
}
__syncthreads();
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
@@ -668,8 +893,12 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
auto dtype = platform::CudnnDataType<T>::type;
const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
#ifdef PADDLE_WITH_HIP
auto compute_format = data_layout == DataLayout::kNHWC ? DataLayout::kNHWC
: DataLayout::kNCHW;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW;
#else
const bool fast_nhwc_batch_norm =
dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
@@ -714,7 +943,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
const int num = transformed_x.numel();
#ifdef HIPCC
const int block = 256;
#else
const int block = 512;
#endif
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid1 = (num + block - 1) / block;
@@ -734,14 +967,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// miopenTensorDescriptor_t data_desc_;
// miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_;

// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
@@ -759,7 +993,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
#elif CUDNN_VERSION_MIN(7, 0, 1)
if (FLAGS_cudnn_batchnorm_spatial_persistent) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -771,13 +1006,14 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
#endif // CUDNN_VERSION_MIN(7, 0, 1)
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
// data_desc_, CudnnDataType<T>::type,
// x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
// const_cast<int *>(strides.data())));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
// data_desc_, mode_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
@@ -871,20 +1107,49 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if (!called) {
#ifdef PADDLE_WITH_HIP
if (compute_format == DataLayout::kNCHW) {
BNBackward<
T, block,
DataLayout::kNCHW><<<grid2, block, 0, dev_ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale->template data<BatchNormParamType<T>>(), saved_mean_data,
saved_var_data, C, N, H * W * D, epsilon,
transformed_d_x.template data<T>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()));
} else {
BNBackward<
T, block,
DataLayout::kNHWC><<<grid2, block, 0, dev_ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale->template data<BatchNormParamType<T>>(), saved_mean_data,
saved_var_data, C, N, H * W * D, epsilon,
transformed_d_x.template data<T>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()));
}
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenBatchNormalizationBackward(
// dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), data_desc_,
// transformed_x.template data<T>(), data_desc_,
// transformed_d_y.template data<T>(), data_desc_,
// transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
// bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
// d_scale->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// d_bias->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnBatchNormalizationBackward(
@@ -931,11 +1196,12 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
#ifdef PADDLE_WITH_HIP
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean when exit.
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -32,6 +32,12 @@ namespace cub = hipcub;
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif
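// LAUNCH_BOUNDS(BlockDim) expands to __launch_bounds__(BlockDim) only on the
// ROCm (hipcc) build, telling the compiler the maximum number of threads per
// block a kernel will be launched with so it can budget registers
// accordingly; on the CUDA build the macro expands to nothing and the
// kernels keep their default launch bounds.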
namespace paddle {
namespace operators {
@@ -58,12 +64,10 @@ using DataLayout = framework::DataLayout;
// axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDX(
const T *x, const T *mean, const T *variance, const T *ddx, const T *dy,
const T *scale, const T *ddscale, const int N, const int C,
const int sample_size, const double epsilon, T *dx) {
const int outer_size = C;
const int inner_size = N * sample_size;
@@ -160,12 +164,10 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
// np.mean(ddx * (x - mean), axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDDY(
const T *x, const T *mean, const T *variance, const T *ddscale,
const T *ddbias, const T *ddx, const T *scale, const int N, const int C,
const int sample_size, const double epsilon, T *ddy) {
const int outer_size = C;
const int inner_size = N * sample_size;
@@ -238,11 +240,10 @@ __global__ void DoubleGradComputeDDY(const T *x, const T *mean,
// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) *
// ddx
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScale(
const T *x, const T *mean, const T *variance, const T *ddx, const T *dy,
const int N, const int C, const int sample_size, const double epsilon,
T *dscale) {
const int outer_size = C;
const int inner_size = N * sample_size;
@@ -302,7 +303,7 @@ __global__ void DoubleGradComputeDScale(const T *x, const T *mean,
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScaleWithGlobal(
const T *ddx, const T *variance, const T *dy, const double epsilon,
const int N, const int C, const int sample_size, T *dscale) {
int outer_size = C;
@@ -422,8 +423,11 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
}
const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(C, max_blocks);
@@ -532,6 +536,5 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
}
}
}
}  // namespace operators
}  // namespace paddle