Unverified commit bd5e97d3 authored by Zhang Ting, committed by GitHub

slice large tensor for cudnn_softmax (#43681)

Parent 827d9992
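The diff below works around cuDNN's 32-bit tensor-descriptor dimensions by slicing the flattened [N, dim] softmax input along the batch dimension: each cuDNN call processes at most std::numeric_limits<int32_t>::max() / dim rows, and the data pointers are advanced between calls. A minimal, self-contained sketch of that chunking idea follows; run_chunk, run_sliced, and the float-only types are illustrative placeholders, not Paddle's API.

// Minimal sketch of the slicing idea (not Paddle code): split a [total_rows, dim]
// buffer into chunks whose element count fits in a 32-bit int, and invoke the
// per-chunk worker once per slice. run_chunk() is a hypothetical stand-in for
// the cuDNN softmax call.
#include <algorithm>
#include <cstdint>
#include <limits>

void run_chunk(const float* in, float* out, int rows, int dim) {
  // Placeholder work so the sketch compiles; the real code calls cuDNN here.
  std::copy(in, in + static_cast<int64_t>(rows) * dim, out);
}

void run_sliced(const float* in, float* out, int64_t total_rows, int dim) {
  // Largest row count whose element count still fits in int32, mirroring the
  // batch_size computation introduced by this commit.
  const int64_t max_rows = std::numeric_limits<int32_t>::max() / dim;
  int64_t remaining = total_rows;
  while (remaining > 0) {
    const int rows = static_cast<int>(std::min(remaining, max_rows));
    run_chunk(in, out, rows, dim);
    in += static_cast<int64_t>(rows) * dim;   // advance past the slice just processed
    out += static_cast<int64_t>(rows) * dim;
    remaining -= rows;
  }
}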
@@ -772,15 +772,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
 template <typename T>
 void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
-                               const DenseTensor& x,
+                               const T* x_data,
                                const int axis,
+                               const int rank,
                                const bool log_mode,
-                               DenseTensor* out) {
-  auto* out_data = out->data<T>();
-  const int rank = x.dims().size();
-  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
+                               const std::vector<int>& tensor_dims,
+                               T* out_data) {
   auto handle = dev_ctx.cudnn_handle();
   GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
@@ -795,7 +792,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
         handle,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        x.data<T>(),
+        x_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         out_data,
@@ -812,25 +809,47 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
         mode,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        x.data<T>(),
+        x_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         out_data));
 #endif
 }
 
+template <typename T>
+void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
+                                     const DenseTensor& x,
+                                     const int axis,
+                                     const bool log_mode,
+                                     DenseTensor* out) {
+  auto* out_data = out->data<T>();
+  auto* x_data = x.data<T>();
+  const int rank = x.dims().size();
+  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
+  int64_t remaining = tensor_dims[0];
+  int dim = tensor_dims[1];
+  int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
+  int offset = batch_size * dim;
+  while (remaining > 0) {
+    tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
+    SoftmaxForwardCudnnKernel<T>(
+        dev_ctx, x_data, axis, rank, log_mode, tensor_dims, out_data);
+    x_data += offset;
+    out_data += offset;
+    remaining -= batch_size;
+  }
+}
+
 template <typename T>
 void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
-                                const DenseTensor& out,
-                                const DenseTensor& dout,
+                                const T* out_data,
+                                const T* dout_data,
                                 const int axis,
+                                const int rank,
                                 const bool log_mode,
-                                DenseTensor* dx) {
-  auto* dx_data = dx->data<T>();
-  int rank = out.dims().size();
-  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
+                                const std::vector<int>& tensor_dims,
+                                T* dx_data) {
   auto handle = dev_ctx.cudnn_handle();
   GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
@@ -846,9 +865,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
         handle,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        out.data<T>(),
+        out_data,
         desc,
-        dout.data<T>(),
+        dout_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         dx_data,
@@ -865,18 +884,52 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
         mode,
         paddle::platform::CudnnDataType<T>::kOne(),
         desc,
-        out.data<T>(),
+        out_data,
         desc,
-        dout.data<T>(),
+        dout_data,
         paddle::platform::CudnnDataType<T>::kZero(),
         desc,
         dx_data));
 #endif
 }
 
+template <typename T>
+void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
+                                      const DenseTensor& out,
+                                      const DenseTensor& dout,
+                                      const int axis,
+                                      const bool log_mode,
+                                      DenseTensor* dx) {
+  auto* dx_data = dx->data<T>();
+  auto* out_data = out.data<T>();
+  auto* dout_data = dout.data<T>();
+  int rank = out.dims().size();
+  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
+  int64_t remaining = tensor_dims[0];
+  int dim = tensor_dims[1];
+  int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
+  int offset = batch_size * dim;
+  while (remaining > 0) {
+    tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
+    SoftmaxBackwardCudnnKernel<T>(dev_ctx,
+                                  out_data,
+                                  dout_data,
+                                  axis,
+                                  rank,
+                                  log_mode,
+                                  tensor_dims,
+                                  dx_data);
+    out_data += offset;
+    dout_data += offset;
+    dx_data += offset;
+    remaining -= batch_size;
+  }
+}
+
 #if CUDNN_VERSION < 8100
 template <>
-inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
+inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
     const GPUContext& dev_ctx,
     const DenseTensor& x,
     const int axis,
@@ -887,7 +940,7 @@ inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
       "8100."));
 }
 template <>
-inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
+inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
     const GPUContext& dev_ctx,
     const DenseTensor& out,
     const DenseTensor& dout,
@@ -933,60 +986,62 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
   int dim = tensor_dims[1];
   int D = tensor_dims[2];
 
-  if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
-    int dim_log2 = static_cast<int>(Log2Ceil(dim));
-    int dim_ceil = 1 << dim_log2;
-    int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
-    int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;
-
-    // use 128 threads per block to maximimize gpu utilization
-    constexpr int threads_per_block = 128;
-
-    int warps_per_block = (threads_per_block / warp_size);
-    int batches_per_block = warps_per_block * batches_per_warp;
-    int blocks = (N + batches_per_block - 1) / batches_per_block;
-    dim3 threads(warp_size, warps_per_block, 1);
-
-    // vectorization read/write
-    using T4 = typename VecT4<T>::Type;
-    using T2 = typename VecT2<T>::Type;
-
-    if (dim % 4 == 0) {
-      SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
-                                               threads,
-                                               dev_ctx,
-                                               out_data,
-                                               x.data<T>(),
-                                               N,
-                                               dim,
-                                               dim,
-                                               dim_log2);
-    } else if (dim % 2 == 0) {
-      SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
-                                               threads,
-                                               dev_ctx,
-                                               out_data,
-                                               x.data<T>(),
-                                               N,
-                                               dim,
-                                               dim,
-                                               dim_log2);
-    } else {
-      SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
-                                              threads,
-                                              dev_ctx,
-                                              out_data,
-                                              x.data<T>(),
-                                              N,
-                                              dim,
-                                              dim,
-                                              dim_log2);
+  if (D == 1) {
+    if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
+      int dim_log2 = static_cast<int>(Log2Ceil(dim));
+      int dim_ceil = 1 << dim_log2;
+      int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
+      int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;
+
+      // use 128 threads per block to maximimize gpu utilization
+      constexpr int threads_per_block = 128;
+
+      int warps_per_block = (threads_per_block / warp_size);
+      int batches_per_block = warps_per_block * batches_per_warp;
+      int blocks = (N + batches_per_block - 1) / batches_per_block;
+      dim3 threads(warp_size, warps_per_block, 1);
+
+      // vectorization read/write
+      using T4 = typename VecT4<T>::Type;
+      using T2 = typename VecT2<T>::Type;
+
+      if (dim % 4 == 0) {
+        SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
+                                                 threads,
+                                                 dev_ctx,
+                                                 out_data,
+                                                 x.data<T>(),
+                                                 N,
+                                                 dim,
+                                                 dim,
+                                                 dim_log2);
+      } else if (dim % 2 == 0) {
+        SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
+                                                 threads,
+                                                 dev_ctx,
+                                                 out_data,
+                                                 x.data<T>(),
+                                                 N,
+                                                 dim,
+                                                 dim,
+                                                 dim_log2);
+      } else {
+        SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
+                                                threads,
+                                                dev_ctx,
+                                                out_data,
+                                                x.data<T>(),
+                                                N,
+                                                dim,
+                                                dim,
+                                                dim_log2);
+      }
+    } else {
+      LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
     }
-  } else if (D > 1) {
+  } else {
     LaunchNormalSoftmaxForward<T, LogMode>(
         dev_ctx, out_data, x.data<T>(), N, dim, D);
-  } else {
-    SoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
   }
 }
@@ -1005,61 +1060,64 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
   int dim = tensor_dims[1];
   int D = tensor_dims[2];
 
-  if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
-    int dim_log2 = Log2Ceil(dim);
-    int dim_ceil = 1 << dim_log2;
-    int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
-    int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;
-
-    constexpr int threads_per_block = 128;
-
-    int warps_per_block = (threads_per_block / warp_size);
-    int batches_per_block = warps_per_block * batches_per_warp;
-    int blocks = (N + batches_per_block - 1) / batches_per_block;
-    dim3 threads(warp_size, warps_per_block, 1);
-
-    // vectorization read/write
-    using T4 = typename VecT4<T>::Type;
-    using T2 = typename VecT2<T>::Type;
-    if (dim % 4 == 0) {
-      SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
-                                                threads,
-                                                dev_ctx,
-                                                dx_data,
-                                                dout.data<T>(),
-                                                out.data<T>(),
-                                                N,
-                                                dim,
-                                                dim,
-                                                dim_log2);
-    } else if (dim % 2 == 0) {
-      SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
-                                                threads,
-                                                dev_ctx,
-                                                dx_data,
-                                                dout.data<T>(),
-                                                out.data<T>(),
-                                                N,
-                                                dim,
-                                                dim,
-                                                dim_log2);
-    } else {
-      SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
-                                               threads,
-                                               dev_ctx,
-                                               dx_data,
-                                               dout.data<T>(),
-                                               out.data<T>(),
-                                               N,
-                                               dim,
-                                               dim,
-                                               dim_log2);
+  if (D == 1) {
+    if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
+      int dim_log2 = Log2Ceil(dim);
+      int dim_ceil = 1 << dim_log2;
+      int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
+      int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;
+
+      constexpr int threads_per_block = 128;
+
+      int warps_per_block = (threads_per_block / warp_size);
+      int batches_per_block = warps_per_block * batches_per_warp;
+      int blocks = (N + batches_per_block - 1) / batches_per_block;
+      dim3 threads(warp_size, warps_per_block, 1);
+
+      // vectorization read/write
+      using T4 = typename VecT4<T>::Type;
+      using T2 = typename VecT2<T>::Type;
+      if (dim % 4 == 0) {
+        SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
+                                                  threads,
+                                                  dev_ctx,
+                                                  dx_data,
+                                                  dout.data<T>(),
+                                                  out.data<T>(),
+                                                  N,
+                                                  dim,
+                                                  dim,
+                                                  dim_log2);
+      } else if (dim % 2 == 0) {
+        SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
+                                                  threads,
+                                                  dev_ctx,
+                                                  dx_data,
+                                                  dout.data<T>(),
+                                                  out.data<T>(),
+                                                  N,
+                                                  dim,
+                                                  dim,
+                                                  dim_log2);
+      } else {
+        SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
+                                                 threads,
+                                                 dev_ctx,
+                                                 dx_data,
+                                                 dout.data<T>(),
+                                                 out.data<T>(),
+                                                 N,
+                                                 dim,
+                                                 dim,
+                                                 dim_log2);
+      }
+    } else {
+      LaunchSoftmaxBackwardCudnnKernel<T>(
+          dev_ctx, out, dout, axis, LogMode, dx);
     }
-  } else if (D > 1) {
+  } else {
     LaunchNormalSoftmaxBackward<T, LogMode>(
         dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
-  } else {
-    SoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
   }
 }
...