Unverified commit bd5e97d3, authored by Zhang Ting, committed by GitHub

slice large tensor for cudnn_softmax (#43681)

Parent 827d9992
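The commit wraps the raw cuDNN calls in new LaunchSoftmaxForwardCudnnKernel / LaunchSoftmaxBackwardCudnnKernel helpers that slice the flattened [N, dim] input along the batch dimension, so each cuDNN call sees at most std::numeric_limits<int32_t>::max() / dim rows (cuDNN descriptors use 32-bit dimensions). The standalone sketch below illustrates only that chunking arithmetic; the main driver, the example sizes, and the printed output are hypothetical and not part of the diff.

// Standalone sketch (not part of the diff): compute the per-slice row count
// and element stride used when a [N, dim] tensor is too large for one call.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  const int dim = 64;                          // hypothetical softmax dimension
  const int64_t N = 3'000'000'000LL / dim;     // hypothetical row count, N * dim > INT32_MAX

  // Largest number of rows whose element count still fits in an int32,
  // mirroring `batch_size = std::numeric_limits<int32_t>::max() / dim`.
  const int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
  const int64_t offset = batch_size * dim;     // element stride between slices

  int64_t remaining = N;
  int64_t start_row = 0;
  while (remaining > 0) {
    const int64_t rows = std::min<int64_t>(remaining, batch_size);
    // A real implementation would call the cuDNN softmax kernel here with
    // tensor_dims[0] = rows, then advance x/out pointers by `offset` elements.
    std::printf("slice: rows [%lld, %lld), %lld elements\n",
                static_cast<long long>(start_row),
                static_cast<long long>(start_row + rows),
                static_cast<long long>(rows * dim));
    start_row += rows;
    remaining -= rows;
  }
  return 0;
}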
@@ -772,15 +772,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
template <typename T>
void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& x,
const T* x_data,
const int axis,
const int rank,
const bool log_mode,
DenseTensor* out) {
auto* out_data = out->data<T>();
const int rank = x.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
const std::vector<int>& tensor_dims,
T* out_data) {
auto handle = dev_ctx.cudnn_handle();
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
@@ -795,7 +792,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
handle,
paddle::platform::CudnnDataType<T>::kOne(),
desc,
x.data<T>(),
x_data,
paddle::platform::CudnnDataType<T>::kZero(),
desc,
out_data,
@@ -812,25 +809,47 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
mode,
paddle::platform::CudnnDataType<T>::kOne(),
desc,
x.data<T>(),
x_data,
paddle::platform::CudnnDataType<T>::kZero(),
desc,
out_data));
#endif
}
template <typename T>
void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& x,
const int axis,
const bool log_mode,
DenseTensor* out) {
auto* out_data = out->data<T>();
auto* x_data = x.data<T>();
const int rank = x.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
int64_t remaining = tensor_dims[0];
int dim = tensor_dims[1];
int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxForwardCudnnKernel<T>(
dev_ctx, x_data, axis, rank, log_mode, tensor_dims, out_data);
x_data += offset;
out_data += offset;
remaining -= batch_size;
}
}
template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const T* out_data,
const T* dout_data,
const int axis,
const int rank,
const bool log_mode,
DenseTensor* dx) {
auto* dx_data = dx->data<T>();
int rank = out.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
const std::vector<int>& tensor_dims,
T* dx_data) {
auto handle = dev_ctx.cudnn_handle();
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
@@ -846,9 +865,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
handle,
paddle::platform::CudnnDataType<T>::kOne(),
desc,
out.data<T>(),
out_data,
desc,
dout.data<T>(),
dout_data,
paddle::platform::CudnnDataType<T>::kZero(),
desc,
dx_data,
@@ -865,18 +884,52 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
mode,
paddle::platform::CudnnDataType<T>::kOne(),
desc,
out.data<T>(),
out_data,
desc,
dout.data<T>(),
dout_data,
paddle::platform::CudnnDataType<T>::kZero(),
desc,
dx_data));
#endif
}
template <typename T>
void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const int axis,
const bool log_mode,
DenseTensor* dx) {
auto* dx_data = dx->data<T>();
auto* out_data = out.data<T>();
auto* dout_data = dout.data<T>();
int rank = out.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
int64_t remaining = tensor_dims[0];
int dim = tensor_dims[1];
int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxBackwardCudnnKernel<T>(dev_ctx,
out_data,
dout_data,
axis,
rank,
log_mode,
tensor_dims,
dx_data);
out_data += offset;
dout_data += offset;
dx_data += offset;
remaining -= batch_size;
}
}
#if CUDNN_VERSION < 8100
template <>
inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const DenseTensor& x,
const int axis,
@@ -887,7 +940,7 @@ inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
"8100."));
}
template <>
inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
@@ -933,60 +986,62 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1];
int D = tensor_dims[2];
if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
}
} else {
SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
}
} else if (D > 1) {
} else {
LaunchNormalSoftmaxForward<T, LogMode>(
dev_ctx, out_data, x.data<T>(), N, dim, D);
} else {
SoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
}
}
@@ -1005,61 +1060,64 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1];
int D = tensor_dims[2];
if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = Log2Ceil(dim);
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) {
SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = Log2Ceil(dim);
int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) {
SwitchWarpSoftmaxBackward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxBackward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
}
} else {
SwitchWarpSoftmaxBackward<T, T, LogMode>(blocks,
threads,
dev_ctx,
dx_data,
dout.data<T>(),
out.data<T>(),
N,
dim,
dim,
dim_log2);
LaunchSoftmaxBackwardCudnnKernel<T>(
dev_ctx, out, dout, axis, LogMode, dx);
}
} else if (D > 1) {
} else {
LaunchNormalSoftmaxBackward<T, LogMode>(
dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
} else {
SoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
}
}