Unverified commit 7879477f authored by ronnywang, committed by GitHub

[ROCM] add CUDA kernel for batch_norm_op (#32393)

Parent 49773f36
@@ -41,6 +41,83 @@ using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
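// BNForwardInference: inference-mode batch norm. A grid-stride loop assigns
// elements to threads; each element is normalized with the running
// mean/variance of its channel and then scaled and shifted.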
template <typename T, framework::DataLayout layout>
static __global__ void BNForwardInference(
const T *x, const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance, const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias, const int C, const int N, const int HxW,
const double epsilon, T *y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int num = N * C * HxW;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> x_sub_mean =
static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
BatchNormParamType<T> inv_var = 1 / sqrt(variance[c] + epsilon);
y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_var + bias[c]);
}
}
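// BNForwardTraining: training-mode batch norm. Each thread block owns one
// channel: it block-reduces sum(x) and sum(x*x) to obtain the batch mean and
// variance, folds them into the running statistics with
// exponentialAverageFactor, and then normalizes the channel in a second pass.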
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
const T *x, const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias, const int C, const int N, const int HxW,
const double epsilon, double exponentialAverageFactor, T *y,
BatchNormParamType<T> *mean, BatchNormParamType<T> *variance,
BatchNormParamType<T> *save_mean,
BatchNormParamType<T> *save_inv_variance) {
int outer_size = C;
int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage mean_storage;
__shared__ typename BlockReduce::TempStorage variance_storage;
__shared__ BatchNormParamType<T> mean_val;
__shared__ BatchNormParamType<T> variance_val;
__shared__ BatchNormParamType<T> inv_var_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> x_square_sum = static_cast<BatchNormParamType<T>>(0);
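// First pass: every thread accumulates partial sums of x and x*x for this
// channel.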
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
BatchNormParamType<T> x_i = static_cast<BatchNormParamType<T>>(x[index]);
x_sum += x_i;
x_square_sum += x_i * x_i;
}
x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
x_square_sum =
BlockReduce(variance_storage).Reduce(x_square_sum, cub::Sum());
if (threadIdx.x == 0) {
mean_val = x_sum / inner_size;
variance_val = x_square_sum / inner_size - mean_val * mean_val;
inv_var_val = 1 / sqrt(variance_val + epsilon);
if (save_mean && save_inv_variance) {
save_mean[i] = mean_val;
save_inv_variance[i] = inv_var_val;
}
mean[i] = (1 - exponentialAverageFactor) * mean_val +
exponentialAverageFactor * mean[i];
variance[i] = (1 - exponentialAverageFactor) * variance_val +
exponentialAverageFactor * variance[i];
}
__syncthreads();
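// Second pass: normalize with the batch statistics and apply scale and bias.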
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
BatchNormParamType<T> x_sub_mean =
static_cast<BatchNormParamType<T>>(x[index]) - mean_val;
y[index] = scale[i] * x_sub_mean * inv_var_val + bias[i];
}
}
}
template <typename T>
class BatchNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
@@ -80,8 +157,12 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
auto dtype = platform::CudnnDataType<T>::type;
#ifdef PADDLE_WITH_HIP
// HIP does not support the NHWC compute format
auto compute_format = DataLayout::kNCHW;
auto compute_format = data_layout == DataLayout::kNHWC ? DataLayout::kNHWC
: DataLayout::kNCHW;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP does not support the NHWC compute format
// auto compute_format = DataLayout::kNCHW;
#else
const bool fast_nhwc_batch_norm =
test_mode ||
@@ -111,14 +192,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t data_desc_;
miopenTensorDescriptor_t bn_param_desc_;
miopenBatchNormMode_t mode_;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// miopenTensorDescriptor_t data_desc_;
// miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_;
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
@@ -138,7 +220,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#ifdef PADDLE_WITH_HIP
mode_ = miopenBNSpatial;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
#elif CUDNN_VERSION_MIN(7, 0, 1)
if (FLAGS_cudnn_batchnorm_spatial_persistent) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -161,14 +244,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
}
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
const_cast<int *>(strides.data())));
// Note: PERSISTENT not implemented for inference
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
// data_desc_, CudnnDataType<T>::type,
// x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
// const_cast<int *>(strides.data())));
// Note: PERSISTENT not implemented for inference
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDeriveBNTensorDescriptor(
// bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
@@ -226,28 +310,53 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
C, est_var->dims()[0], est_var->dims()));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenBatchNormalizationForwardInference(
handle, miopenBNSpatial,
const_cast<void *>(
static_cast<const void *>(CudnnDataType<T>::kOne())),
const_cast<void *>(
static_cast<const void *>(CudnnDataType<T>::kZero())),
data_desc_,
static_cast<const void *>(transformed_x.template data<T>()),
data_desc_,
static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())),
bn_param_desc_,
const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
est_mean->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
est_var->template data<BatchNormParamType<T>>())),
epsilon));
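// Launch the handwritten inference kernel: roughly one thread per element,
// with a grid-stride loop covering the whole N*C*H*W*D tensor.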
const int block_size = 256;
const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
if (compute_format == DataLayout::kNCHW) {
BNForwardInference<
T,
DataLayout::kNCHW><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
epsilon, transformed_y.template data<T>());
} else {
BNForwardInference<
T,
DataLayout::kNHWC><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
epsilon, transformed_y.template data<T>());
}
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenBatchNormalizationForwardInference(
// handle, miopenBNSpatial,
// const_cast<void *>(
// static_cast<const void *>(CudnnDataType<T>::kOne())),
// const_cast<void *>(
// static_cast<const void *>(CudnnDataType<T>::kZero())),
// data_desc_,
// static_cast<const void *>(transformed_x.template data<T>()),
// data_desc_,
// static_cast<void *>(
// transformed_y.template mutable_data<T>(ctx.GetPlace())),
// bn_param_desc_,
// const_cast<void *>(static_cast<const void *>(
// scale->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// bias->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// est_mean->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// est_var->template data<BatchNormParamType<T>>())),
// epsilon));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnBatchNormalizationForwardInference(
@@ -365,34 +474,66 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if (!called) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenBatchNormalizationForwardTraining(
handle, mode_, const_cast<void *>(static_cast<const void *>(
CudnnDataType<T>::kOne())),
const_cast<void *>(
static_cast<const void *>(CudnnDataType<T>::kZero())),
data_desc_,
static_cast<const void *>(transformed_x.template data<T>()),
data_desc_,
static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())),
bn_param_desc_,
const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())),
this_factor,
static_cast<void *>(
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())),
static_cast<void *>(variance_out->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace())),
epsilon,
static_cast<void *>(
saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())),
static_cast<void *>(saved_variance->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace()))));
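// Launch the handwritten training kernel: one thread block per channel,
// capped at the number of blocks the device can keep resident.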
const int num = transformed_x.numel();
const int block = 256;
const int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
const int grid = std::min(C, max_blocks);
if (compute_format == DataLayout::kNCHW) {
BNForwardTraining<
T, block,
DataLayout::kNCHW><<<grid, block, 0, dev_ctx.stream()>>>(
transformed_x.template data<T>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
epsilon, this_factor, transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
} else {
BNForwardTraining<
T, block,
DataLayout::kNHWC><<<grid, block, 0, dev_ctx.stream()>>>(
transformed_x.template data<T>(),
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), C, N, H * W * D,
epsilon, this_factor, transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
}
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenBatchNormalizationForwardTraining(
// handle, mode_, const_cast<void *>(static_cast<const void *>(
// CudnnDataType<T>::kOne())),
// const_cast<void *>(
// static_cast<const void *>(CudnnDataType<T>::kZero())),
// data_desc_,
// static_cast<const void *>(transformed_x.template data<T>()),
// data_desc_,
// static_cast<void *>(
// transformed_y.template mutable_data<T>(ctx.GetPlace())),
// bn_param_desc_,
// const_cast<void *>(static_cast<const void *>(
// scale->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// bias->template data<BatchNormParamType<T>>())),
// this_factor,
// static_cast<void *>(
// mean_out->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace())),
// static_cast<void *>(variance_out->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace())),
// epsilon,
// static_cast<void *>(
// saved_mean->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace())),
// static_cast<void *>(saved_variance->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace()))));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnBatchNormalizationForwardTraining(
@@ -423,11 +564,12 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
ctx, &transformed_y, y);
}
#ifdef PADDLE_WITH_HIP
// clean up on exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean up on exit.
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else
// clean up on exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -439,7 +581,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
};
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
static __global__ LAUNCH_BOUNDS(BlockDim) void KeBNBackwardScaleBias(
const T *dy, const T *x, const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *variance, const double epsilon, const int N,
const int C, const int HxW, BatchNormParamType<T> *dscale,
@@ -526,13 +668,97 @@ class InplaceHelper {
};
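// BNBackward: fused backward pass. One thread block per channel computes
// dscale, dbias and dx in a single kernel; if the forward pass did not save
// the batch statistics, they are recomputed from x.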
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void BNBackwardData(const T *dy,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *mean,
const T *x,
const BatchNormParamType<T> *variance,
const int C, const int N, const int HxW,
T *dx) {
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
const T *dy, const T *x, const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *saved_mean,
const BatchNormParamType<T> *saved_inv_variance, const int C, const int N,
const int HxW, const double epsilon, T *dx, BatchNormParamType<T> *dscale,
BatchNormParamType<T> *dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
__shared__ typename BlockReduce::TempStorage mean_storage;
__shared__ typename BlockReduce::TempStorage variance_storage;
__shared__ BatchNormParamType<T> inv_var_val;
__shared__ BatchNormParamType<T> mean_val;
__shared__ BatchNormParamType<T> dscale_val;
__shared__ BatchNormParamType<T> dbias_val;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
if (saved_mean && saved_inv_variance) {
if (threadIdx.x == 0) {
inv_var_val = saved_inv_variance[i];
mean_val = saved_mean[i];
}
} else {
BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
BatchNormParamType<T> x_square_sum =
static_cast<BatchNormParamType<T>>(0);
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
BatchNormParamType<T> x_i =
static_cast<BatchNormParamType<T>>(x[index]);
x_sum += x_i;
x_square_sum += x_i * x_i;
}
x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
x_square_sum =
BlockReduce(variance_storage).Reduce(x_square_sum, cub::Sum());
if (threadIdx.x == 0) {
mean_val = x_sum / inner_size;
inv_var_val =
1 / sqrt(x_square_sum / inner_size - mean_val * mean_val + epsilon);
}
}
__syncthreads();
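// Accumulate per-thread partial sums of dy * (x - mean) and dy for the
// dscale and dbias reductions.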
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
BatchNormParamType<T> dy_i =
static_cast<BatchNormParamType<T>>(dy[index]);
ds_sum +=
dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
db_sum += dy_i;
}
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
if (threadIdx.x == 0) {
dscale_val = ds_sum * inv_var_val;
dbias_val = db_sum;
dscale[i] = dscale_val;
dbias[i] = dbias_val;
}
__syncthreads();
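// With dscale and dbias known, compute dx for every element of this channel.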
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
dx[index] = scale[i] * inv_var_val *
(static_cast<BatchNormParamType<T>>(dy[index]) -
dbias_val / static_cast<BatchNormParamType<T>>(inner_size) -
(static_cast<BatchNormParamType<T>>(x[index]) - mean_val) *
inv_var_val * dscale_val / inner_size);
}
}
}
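// BNBackwardData: block-per-channel kernel that computes only the input
// gradient dx.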
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
const T *dy, const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *mean, const T *x,
const BatchNormParamType<T> *variance, const int C, const int N,
const int HxW, T *dx) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
@@ -567,7 +793,6 @@ static __global__ void BNBackwardData(const T *dy,
dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
}
__syncthreads();
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
@@ -668,8 +893,12 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
auto dtype = platform::CudnnDataType<T>::type;
const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
#ifdef PADDLE_WITH_HIP
// HIP does not support the NHWC compute format
auto compute_format = DataLayout::kNCHW;
auto compute_format = data_layout == DataLayout::kNHWC ? DataLayout::kNHWC
: DataLayout::kNCHW;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP does not support the NHWC compute format
// auto compute_format = DataLayout::kNCHW;
#else
const bool fast_nhwc_batch_norm =
dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
@@ -714,7 +943,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
const int num = transformed_x.numel();
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid1 = (num + block - 1) / block;
@@ -734,14 +967,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
// ------------------- cudnn descriptors ---------------------
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t data_desc_;
miopenTensorDescriptor_t bn_param_desc_;
miopenBatchNormMode_t mode_;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// miopenTensorDescriptor_t data_desc_;
// miopenTensorDescriptor_t bn_param_desc_;
// miopenBatchNormMode_t mode_;
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
#else
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
@@ -759,7 +993,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
#ifdef PADDLE_WITH_HIP
mode_ = miopenBNSpatial;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// mode_ = miopenBNSpatial;
#elif CUDNN_VERSION_MIN(7, 0, 1)
if (FLAGS_cudnn_batchnorm_spatial_persistent) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -771,13 +1006,14 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
#endif // CUDNN_VERSION_MIN(7, 0, 1)
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
const_cast<int *>(strides.data())));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
data_desc_, mode_));
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
// data_desc_, CudnnDataType<T>::type,
// x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
// const_cast<int *>(strides.data())));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
// data_desc_, mode_));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
@@ -871,20 +1107,49 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if (!called) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_,
transformed_x.template data<T>(), data_desc_,
transformed_d_y.template data<T>(), data_desc_,
transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));
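// Launch the fused backward kernel: one thread block per channel produces
// dx, dscale and dbias.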
if (compute_format == DataLayout::kNCHW) {
BNBackward<
T, block,
DataLayout::kNCHW><<<grid2, block, 0, dev_ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale->template data<BatchNormParamType<T>>(), saved_mean_data,
saved_var_data, C, N, H * W * D, epsilon,
transformed_d_x.template data<T>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()));
} else {
BNBackward<
T, block,
DataLayout::kNHWC><<<grid2, block, 0, dev_ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale->template data<BatchNormParamType<T>>(), saved_mean_data,
saved_var_data, C, N, H * W * D, epsilon,
transformed_d_x.template data<T>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()));
}
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenBatchNormalizationBackward(
// dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), data_desc_,
// transformed_x.template data<T>(), data_desc_,
// transformed_d_y.template data<T>(), data_desc_,
// transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
// bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
// d_scale->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// d_bias->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnBatchNormalizationBackward(
@@ -931,11 +1196,12 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
#ifdef PADDLE_WITH_HIP
// clean up on exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// clean up on exit.
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
// PADDLE_ENFORCE_CUDA_SUCCESS(
// platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
#else
// clean up on exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -32,6 +32,12 @@ namespace cub = hipcub;
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
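// LAUNCH_BOUNDS passes the block size to the HIP compiler via
// __launch_bounds__ so it can budget registers for these kernels; it expands
// to nothing when building with CUDA.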
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#else
#define LAUNCH_BOUNDS(BlockDim)
#endif
namespace paddle {
namespace operators {
@@ -58,12 +64,10 @@ using DataLayout = framework::DataLayout;
// axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDX(const T *x, const T *mean,
const T *variance, const T *ddx,
const T *dy, const T *scale,
const T *ddscale, const int N, const int C,
const int sample_size, const double epsilon,
T *dx) {
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDX(
const T *x, const T *mean, const T *variance, const T *ddx, const T *dy,
const T *scale, const T *ddscale, const int N, const int C,
const int sample_size, const double epsilon, T *dx) {
const int outer_size = C;
const int inner_size = N * sample_size;
@@ -160,12 +164,10 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
// np.mean(ddx * (x - mean), axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDDY(const T *x, const T *mean,
const T *variance, const T *ddscale,
const T *ddbias, const T *ddx,
const T *scale, const int N, const int C,
const int sample_size,
const double epsilon, T *ddy) {
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDDY(
const T *x, const T *mean, const T *variance, const T *ddscale,
const T *ddbias, const T *ddx, const T *scale, const int N, const int C,
const int sample_size, const double epsilon, T *ddy) {
const int outer_size = C;
const int inner_size = N * sample_size;
@@ -238,11 +240,10 @@ __global__ void DoubleGradComputeDDY(const T *x, const T *mean,
// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) *
// ddx
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDScale(const T *x, const T *mean,
const T *variance, const T *ddx,
const T *dy, const int N, const int C,
const int sample_size,
const double epsilon, T *dscale) {
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScale(
const T *x, const T *mean, const T *variance, const T *ddx, const T *dy,
const int N, const int C, const int sample_size, const double epsilon,
T *dscale) {
const int outer_size = C;
const int inner_size = N * sample_size;
@@ -302,7 +303,7 @@ __global__ void DoubleGradComputeDScale(const T *x, const T *mean,
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDScaleWithGlobal(
__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScaleWithGlobal(
const T *ddx, const T *variance, const T *dy, const double epsilon,
const int N, const int C, const int sample_size, T *dscale) {
int outer_size = C;
@@ -422,8 +423,11 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
}
const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
#ifdef __HIPCC__
const int block = 256;
#else
const int block = 512;
#endif
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(C, max_blocks);
@@ -532,6 +536,5 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
}
}
}
} // namespace operators
} // namespace paddle