From 8e1b020488a8b9abd116fa0458a1bec0221ace43 Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Wed, 9 Feb 2022 16:54:39 +0800
Subject: [PATCH] Optimize performance of softmax_fwd when axis!=-1 (#38602)

* Optimize performance of softmax_fwd when axis!=-1

* use functor

* support hip

* fix functor
---
 paddle/fluid/operators/softmax_cudnn_op.cu.h | 165 ++++++++++++++++++-
 1 file changed, 164 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.h b/paddle/fluid/operators/softmax_cudnn_op.cu.h
index 236ea448f30..85bc3946bae 100644
--- a/paddle/fluid/operators/softmax_cudnn_op.cu.h
+++ b/paddle/fluid/operators/softmax_cudnn_op.cu.h
@@ -186,6 +186,58 @@ struct UnaryDivFunctor {
   Tx n_inv;
 };
 
+template <typename Tx, typename Ty = Tx>
+struct SoftmaxForwardFunctor {
+  HOSTDEVICE inline SoftmaxForwardFunctor(Tx max, Tx sum)
+      : max(max), sum(sum) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x - max) / sum);
+  }
+
+ private:
+  Tx max;
+  Tx sum;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct SoftmaxBackwardFunctor {
+  HOSTDEVICE inline SoftmaxBackwardFunctor(Tx sum) : sum(sum) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
+    return static_cast<Ty>(out * (grad_out - sum));
+  }
+
+ private:
+  Tx sum;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct LogSoftmaxForwardFunctor {
+  HOSTDEVICE inline LogSoftmaxForwardFunctor(Tx max, Tx sum)
+      : max(max), log_sum(std::log(sum)) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x - max - log_sum);
+  }
+
+ private:
+  Tx max;
+  Tx log_sum;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct LogSoftmaxBackwardFunctor {
+  HOSTDEVICE inline LogSoftmaxBackwardFunctor(Tx sum) : sum(sum) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
+    return static_cast<Ty>(grad_out - std::exp(out) * sum);
+  }
+
+ private:
+  Tx sum;
+};
+
 /*
 Core function of computing softmax forward for axis=-1.
 The computation includes
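The four functors above capture the elementwise tail of softmax once the per-row statistics are known: with M = max_j x_j and S = sum_j exp(x_j - M), softmax(x) = exp(x - M) / S and logsoftmax(x) = x - M - log(S). As a quick sanity check of that math, here is a host-only C++ sketch; it is an illustration, not part of the patch, with HOSTDEVICE and the Tx/Ty template machinery dropped for brevity.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float x[4] = {1.f, 2.f, 3.f, 4.f};

  // Row statistics, as the surrounding kernels compute them.
  float max_v = *std::max_element(x, x + 4);
  float sum = 0.f;
  for (float v : x) sum += std::exp(v - max_v);

  // SoftmaxForwardFunctor:    exp(x - max) / sum
  // LogSoftmaxForwardFunctor: x - max - log(sum)
  for (float v : x) {
    printf("softmax=%.6f  log_softmax=%.6f\n",
           std::exp(v - max_v) / sum, v - max_v - std::log(sum));
  }
  return 0;
}
```

The backward functors encode the matching gradient identities: for softmax, grad_in = out * (grad_out - sum_j grad_out_j * out_j), and for logsoftmax, grad_in = grad_out - exp(out) * sum_j grad_out_j; in both cases the `sum` passed to the constructor is that precomputed inner product.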
@@ -255,7 +307,8 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
                                       ReduceMaxFunctor<AccT>(), true);
   WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
 
-  // compute sum
+// compute sum
+#pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
     kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
         &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
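The next hunk adds the new forward path for axis != -1. It views the input as a 3-D problem: high_dim (the flattened dimensions before the softmax axis), mid_dim (the axis being reduced), and low_dim (the flattened dimensions after it). The following host-side sketch of that decomposition uses a hypothetical helper name of my own; it is not code from the patch.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper (not in the patch): flatten a tensor shape around
// `axis` into the (high_dim, mid_dim, low_dim) view used below, where an
// element is addressed as
//   offset = high_id * (mid_dim * low_dim) + mid_id * low_dim + low_id.
static void DecomposeAroundAxis(const std::vector<int>& shape, int axis,
                                int* high_dim, int* mid_dim, int* low_dim) {
  *high_dim = 1;
  for (int i = 0; i < axis; ++i) *high_dim *= shape[i];
  *mid_dim = shape[axis];
  *low_dim = 1;
  for (size_t i = axis + 1; i < shape.size(); ++i) *low_dim *= shape[i];
}

int main() {
  int high, mid, low;
  // Softmax over axis 1 of an NCHW tensor [8, 16, 32, 32].
  DecomposeAroundAxis({8, 16, 32, 32}, 1, &high, &mid, &low);
  printf("high=%d mid=%d low=%d\n", high, mid, low);  // high=8 mid=16 low=1024
  return 0;
}
```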
@@ -443,6 +496,113 @@ void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
 #undef SOFTMAX_WARP_FORWARD_CASE
 #undef SOFTMAX_WARP_BACKWARD_CASE
 
+/**
+ *
+ * Better performance when axis != -1
+ */
+
+static void GetGridDim(int high_dim, int mid_dim, int low_dim,
+                       const dim3& block, dim3* grid) {
+  int device_id = paddle::platform::GetCurrentDeviceId();
+  int max_mp = paddle::platform::GetGPUMultiProcessors(device_id);
+  int max_threads_per_mp =
+      paddle::platform::GetGPUMaxThreadsPerMultiProcessor(device_id);
+  int max_threads = max_threads_per_mp * max_mp;
+  int num_threads = block.x * block.y;
+  int max_num_blocks = max_threads / num_threads;
+
+  int grid_x = (low_dim + block.x - 1) / block.x;
+  grid_x = std::min(grid_x, max_num_blocks);
+  int grid_y = (max_num_blocks + grid_x - 1) / grid_x;
+  grid_y = std::min(grid_y, high_dim);
+  grid->x = grid_x;
+  grid->y = grid_y;
+}
+
+static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
+#ifdef __HIPCC__
+  constexpr int max_num_threads = 256;
+#else
+  constexpr int max_num_threads = 1024;
+#endif
+  int block_x = 1 << log2_ceil(low_dim);
+  int block_y = 1 << log2_ceil(mid_dim);
+  block->x = std::min(block_x, 32);
+  block->y = std::min(block_y, static_cast<int>(max_num_threads / block->x));
+  block->x = std::min(block_x, static_cast<int>(max_num_threads / block->y));
+}
+
+static void GetLaunchConfig(int high_dim, int mid_dim, int low_dim, dim3* grid,
+                            dim3* block) {
+  GetBlockDim(mid_dim, low_dim, block);
+  GetGridDim(high_dim, mid_dim, low_dim, *block, grid);
+}
+
+template <typename T, typename AccT,
+          template <typename, typename> class Functor>
+__global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim,
+                                     int mid_dim, int low_dim) {
+  using kMode = kps::details::ReduceMode;
+  const int high_stride = mid_dim * low_dim;
+  const int mid_stride = low_dim;
+  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
+    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
+         low_id += blockDim.x * gridDim.x) {
+      const int input_offset = high_id * high_stride + low_id;
+
+      // 1. reduce max
+      AccT max_value = -std::numeric_limits<AccT>::infinity();
+      AccT value = -std::numeric_limits<AccT>::infinity();
+      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+        value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
+        max_value = kps::MaxFunctor<AccT>()(max_value, value);
+      }
+
+      if (blockDim.y > 1) {
+        kps::Reduce<AccT, 1, 1, 1, kps::MaxFunctor<AccT>, kMode::kGlobalMode>(
+            &max_value, &max_value, kps::MaxFunctor<AccT>(), false);
+      }
+
+      // 2. reduce sum
+      AccT sum = 0;
+      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+        value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
+        sum += std::exp(value - max_value);
+      }
+      if (blockDim.y > 1) {
+        kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
+            &sum, &sum, kps::AddFunctor<AccT>(), false);
+      }
+
+      // 3. (log)softmax
+      Functor<AccT, T> functor(max_value, sum);
+      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+        int data_offset = input_offset + mid_id * mid_stride;
+        output[data_offset] = functor(static_cast<AccT>(input[data_offset]));
+      }
+    }
+  }
+}
+
+template <typename T, bool LogMode = false>
+void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
+                                T* output_data, const T* input_data,
+                                int high_dim, int mid_dim, int low_dim) {
+  using AccT = typename details::MPTypeTrait<T>::Type;
+  dim3 grid, block;
+  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
+  if (LogMode) {
+    NormalSoftmaxForward<
+        T, AccT,
+        LogSoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
+        output_data, input_data, high_dim, mid_dim, low_dim);
+  } else {
+    NormalSoftmaxForward<
+        T, AccT, SoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
+        output_data, input_data, high_dim, mid_dim, low_dim);
+  }
+}
+
 template <typename T, bool LogMode = false>
 void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                     const Tensor& x, const int input_axis,
@@ -490,6 +650,9 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
           out_data, x.data<T>(), N, dim, dim, kDimLog2);
     }
+  } else if (D > 1) {
+    LaunchNormalSoftmaxForward<T, LogMode>(dev_ctx, out_data, x.data<T>(), N,
+                                           dim, D);
   } else {
     ScopedTensorDescriptor desc;
     std::vector<int> tensor_dims = {N, dim, D, 1};
-- 
GitLab
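Because NormalSoftmaxForward depends on Paddle's kps primitives, here is a dependency-free CUDA sketch of the same three-pass pattern (my illustration with a hypothetical kernel name, not the patch's code), simplified to the blockDim.y == 1 case so the cross-thread kps::Reduce calls drop out. Each thread owns one (high_id, low_id) column and walks the mid dimension three times: max, sum of exponentials, then the normalized write-back. It compiles standalone with nvcc for experimenting with the indexing.

```cuda
#include <cuda_runtime.h>
#include <cfloat>
#include <cstdio>

__global__ void NaiveNormalSoftmaxFwd(float* out, const float* in,
                                      int high_dim, int mid_dim, int low_dim) {
  const int high_stride = mid_dim * low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      const int base = high_id * high_stride + low_id;
      // 1. reduce max over the mid dimension
      float max_value = -FLT_MAX;
      for (int mid = 0; mid < mid_dim; ++mid)
        max_value = fmaxf(max_value, in[base + mid * low_dim]);
      // 2. reduce sum of exp(x - max)
      float sum = 0.f;
      for (int mid = 0; mid < mid_dim; ++mid)
        sum += expf(in[base + mid * low_dim] - max_value);
      // 3. softmax write-back (SoftmaxForwardFunctor's formula)
      for (int mid = 0; mid < mid_dim; ++mid)
        out[base + mid * low_dim] =
            expf(in[base + mid * low_dim] - max_value) / sum;
    }
  }
}

int main() {
  const int high = 2, mid = 3, low = 4, n = high * mid * low;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = static_cast<float>(i % 5);
  NaiveNormalSoftmaxFwd<<<dim3(1, high), dim3(low, 1)>>>(out, in, high, mid,
                                                         low);
  cudaDeviceSynchronize();
  printf("column 0: %f %f %f\n", out[0], out[low], out[2 * low]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

The real kernel differs in one important way: when GetBlockDim assigns blockDim.y > 1 threads to the mid dimension (large mid_dim, small low_dim), each thread reduces only a strided slice, and the kps::Reduce<..., kGlobalMode> calls combine those partial max and sum values across threads before step 3.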