From f5a3b4272c5f6b6bf2dec93996392f19d5e3fa91 Mon Sep 17 00:00:00 2001
From: Shaojie WANG
Date: Tue, 30 May 2023 06:04:25 -0700
Subject: [PATCH] softmax fwd: force vec size to 1 when dtype is float (#54183)

* softmax fwd: force vec size to 1 when dtype is float

* use 1024 as threshold to use cudnn
---
 paddle/phi/kernels/gpudnn/softmax_gpudnn.h | 58 ++++++++++++++--------
 1 file changed, 36 insertions(+), 22 deletions(-)

diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
index 9db814f1986..fb434b5c9cf 100644
--- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
+++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
@@ -798,6 +798,7 @@ void SwitchWarpSoftmaxForward(const IndexType blocks,
     SOFTMAX_WARP_FORWARD_CASE(7, AccT);
     SOFTMAX_WARP_FORWARD_CASE(8, AccT);
     SOFTMAX_WARP_FORWARD_CASE(9, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(10, AccT);
     default:
       break;
   }
@@ -836,6 +837,7 @@ void SwitchWarpSoftmaxBackward(const int blocks,
     SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
     SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
     SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
+    SOFTMAX_WARP_BACKWARD_CASE(10, AccT);
     default:
       break;
   }
@@ -1262,7 +1264,7 @@ bool UseCudnnSoftmax(const GPUContext& ctx,
 #endif
     }
   }
-  constexpr int max_dim = 512;
+  constexpr int max_dim = 1024;
   if (!cudnn_available || !last_dim ||
       (softmax_dim <= max_dim && sizeof(T) <= 4) ||
       softmax_dim >= MATRIX_SOFTMAX_THREAHOLD) {
@@ -1311,27 +1313,7 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
     using T4 = typename VecT4<T>::Type;
     using T2 = typename VecT2<T>::Type;
 
-    if (dim % 4 == 0) {
-      SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
-                                               threads,
-                                               dev_ctx,
-                                               out_data,
-                                               x.data<T>(),
-                                               N,
-                                               dim,
-                                               dim,
-                                               dim_log2);
-    } else if (dim % 2 == 0) {
-      SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
-                                               threads,
-                                               dev_ctx,
-                                               out_data,
-                                               x.data<T>(),
-                                               N,
-                                               dim,
-                                               dim,
-                                               dim_log2);
-    } else {
+    if (std::is_same<T, float>::value) {
       SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
                                               threads,
                                               dev_ctx,
@@ -1341,6 +1323,38 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
                                               dim,
                                               dim,
                                               dim_log2);
+    } else {
+      if (dim % 4 == 0) {
+        SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
+                                                 threads,
+                                                 dev_ctx,
+                                                 out_data,
+                                                 x.data<T>(),
+                                                 N,
+                                                 dim,
+                                                 dim,
+                                                 dim_log2);
+      } else if (dim % 2 == 0) {
+        SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
+                                                 threads,
+                                                 dev_ctx,
+                                                 out_data,
+                                                 x.data<T>(),
+                                                 N,
+                                                 dim,
+                                                 dim,
+                                                 dim_log2);
+      } else {
+        SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
+                                                threads,
+                                                dev_ctx,
+                                                out_data,
+                                                x.data<T>(),
+                                                N,
+                                                dim,
+                                                dim,
+                                                dim_log2);
+      }
     }
   } else {
     LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
-- 
GitLab