未验证 提交 bd5e97d3 编写于 作者: Z Zhang Ting 提交者: GitHub

slice large tensor for cudnn_softmax (#43681)

上级 827d9992
...@@ -772,15 +772,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims, ...@@ -772,15 +772,12 @@ static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
template <typename T> template <typename T>
void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& x, const T* x_data,
const int axis, const int axis,
const int rank,
const bool log_mode, const bool log_mode,
DenseTensor* out) { const std::vector<int>& tensor_dims,
auto* out_data = out->data<T>(); T* out_data) {
const int rank = x.dims().size();
std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
...@@ -795,7 +792,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -795,7 +792,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
handle, handle,
paddle::platform::CudnnDataType<T>::kOne(), paddle::platform::CudnnDataType<T>::kOne(),
desc, desc,
x.data<T>(), x_data,
paddle::platform::CudnnDataType<T>::kZero(), paddle::platform::CudnnDataType<T>::kZero(),
desc, desc,
out_data, out_data,
...@@ -812,7 +809,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -812,7 +809,7 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
mode, mode,
paddle::platform::CudnnDataType<T>::kOne(), paddle::platform::CudnnDataType<T>::kOne(),
desc, desc,
x.data<T>(), x_data,
paddle::platform::CudnnDataType<T>::kZero(), paddle::platform::CudnnDataType<T>::kZero(),
desc, desc,
out_data)); out_data));
...@@ -820,17 +817,39 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -820,17 +817,39 @@ void SoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
} }
template <typename T> template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
const DenseTensor& out, const DenseTensor& x,
const DenseTensor& dout,
const int axis, const int axis,
const bool log_mode, const bool log_mode,
DenseTensor* dx) { DenseTensor* out) {
auto* dx_data = dx->data<T>(); auto* out_data = out->data<T>();
auto* x_data = x.data<T>();
const int rank = x.dims().size();
int rank = out.dims().size(); std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis); int64_t remaining = tensor_dims[0];
int dim = tensor_dims[1];
int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
int offset = batch_size * dim;
while (remaining > 0) {
tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
SoftmaxForwardCudnnKernel<T>(
dev_ctx, x_data, axis, rank, log_mode, tensor_dims, out_data);
x_data += offset;
out_data += offset;
remaining -= batch_size;
}
}
template <typename T>
void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
const T* out_data,
const T* dout_data,
const int axis,
const int rank,
const bool log_mode,
const std::vector<int>& tensor_dims,
T* dx_data) {
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW; GPUDNNDataLayout layout = GPUDNNDataLayout::kNCHW;
...@@ -846,9 +865,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -846,9 +865,9 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
handle, handle,
paddle::platform::CudnnDataType<T>::kOne(), paddle::platform::CudnnDataType<T>::kOne(),
desc, desc,
out.data<T>(), out_data,
desc, desc,
dout.data<T>(), dout_data,
paddle::platform::CudnnDataType<T>::kZero(), paddle::platform::CudnnDataType<T>::kZero(),
desc, desc,
dx_data, dx_data,
...@@ -865,18 +884,52 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx, ...@@ -865,18 +884,52 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
mode, mode,
paddle::platform::CudnnDataType<T>::kOne(), paddle::platform::CudnnDataType<T>::kOne(),
desc, desc,
out.data<T>(), out_data,
desc, desc,
dout.data<T>(), dout_data,
paddle::platform::CudnnDataType<T>::kZero(), paddle::platform::CudnnDataType<T>::kZero(),
desc, desc,
dx_data)); dx_data));
#endif #endif
} }
template <typename T>
// Backward softmax via cuDNN, sliced along the outer (batch) dimension so
// that each individual cuDNN call stays within cuDNN's int32 element-count
// limits. Slicing is driven by batch_size = INT32_MAX / dim, so every chunk
// holds at most INT32_MAX elements.
//
// Parameters:
//   dev_ctx  - GPU device context providing the cuDNN handle.
//   out      - forward softmax output tensor.
//   dout     - gradient w.r.t. the softmax output.
//   axis     - softmax axis (forwarded to GetSoftmaxTensorDims).
//   log_mode - true for log-softmax backward, false for plain softmax.
//   dx       - output: gradient w.r.t. the softmax input (pre-allocated).
void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
                                      const DenseTensor& out,
                                      const DenseTensor& dout,
                                      const int axis,
                                      const bool log_mode,
                                      DenseTensor* dx) {
  auto* dx_data = dx->data<T>();
  auto* out_data = out.data<T>();
  auto* dout_data = dout.data<T>();
  int rank = out.dims().size();
  // tensor_dims = {N (outer), dim (softmax axis), D (inner), 1} — see
  // GetSoftmaxTensorDims; tensor_dims[0] is the slice-able batch count.
  std::vector<int> tensor_dims = GetSoftmaxTensorDims(out.dims(), axis);
  int64_t remaining = tensor_dims[0];
  int dim = tensor_dims[1];
  // Largest number of batches whose element count (batch_size * dim) still
  // fits in int32 — assumes dim > 0 (D == 1 on this path; dim == 0 would
  // divide by zero — callers are expected to reject empty tensors).
  int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
  // Keep the per-chunk pointer stride in 64 bits: batch_size * dim is an
  // int64_t product, and narrowing it to int would be a latent truncation
  // bug if the slicing policy above ever changes.
  const int64_t offset = batch_size * dim;
  while (remaining > 0) {
    // Last chunk may be smaller than batch_size.
    tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
    SoftmaxBackwardCudnnKernel<T>(dev_ctx,
                                  out_data,
                                  dout_data,
                                  axis,
                                  rank,
                                  log_mode,
                                  tensor_dims,
                                  dx_data);
    out_data += offset;
    dout_data += offset;
    dx_data += offset;
    remaining -= batch_size;
  }
}
#if CUDNN_VERSION < 8100 #if CUDNN_VERSION < 8100
template <> template <>
inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>( inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx, const GPUContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const int axis, const int axis,
...@@ -887,7 +940,7 @@ inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>( ...@@ -887,7 +940,7 @@ inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
"8100.")); "8100."));
} }
template <> template <>
inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>( inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx, const GPUContext& dev_ctx,
const DenseTensor& out, const DenseTensor& out,
const DenseTensor& dout, const DenseTensor& dout,
...@@ -933,7 +986,8 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, ...@@ -933,7 +986,8 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1]; int dim = tensor_dims[1];
int D = tensor_dims[2]; int D = tensor_dims[2];
if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) { if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = static_cast<int>(Log2Ceil(dim)); int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2; int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32; int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
...@@ -982,11 +1036,12 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, ...@@ -982,11 +1036,12 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
dim, dim,
dim_log2); dim_log2);
} }
} else if (D > 1) { } else {
LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
}
} else {
LaunchNormalSoftmaxForward<T, LogMode>( LaunchNormalSoftmaxForward<T, LogMode>(
dev_ctx, out_data, x.data<T>(), N, dim, D); dev_ctx, out_data, x.data<T>(), N, dim, D);
} else {
SoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
} }
} }
...@@ -1005,7 +1060,8 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx, ...@@ -1005,7 +1060,8 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
int dim = tensor_dims[1]; int dim = tensor_dims[1];
int D = tensor_dims[2]; int D = tensor_dims[2];
if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) { if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = Log2Ceil(dim); int dim_log2 = Log2Ceil(dim);
int dim_ceil = 1 << dim_log2; int dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32; int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
...@@ -1055,11 +1111,13 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx, ...@@ -1055,11 +1111,13 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
dim, dim,
dim_log2); dim_log2);
} }
} else if (D > 1) { } else {
LaunchSoftmaxBackwardCudnnKernel<T>(
dev_ctx, out, dout, axis, LogMode, dx);
}
} else {
LaunchNormalSoftmaxBackward<T, LogMode>( LaunchNormalSoftmaxBackward<T, LogMode>(
dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D); dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N, dim, D);
} else {
SoftmaxBackwardCudnnKernel<T>(dev_ctx, out, dout, axis, LogMode, dx);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册