From 9012787f95004409d721dcfc40a95b6698234a3a Mon Sep 17 00:00:00 2001
From: carryyu <569782149@qq.com>
Date: Thu, 29 Sep 2022 20:45:56 +0800
Subject: [PATCH] Optimize softmax's performance when dim_size >= 100000.
 (#46535)

---
 paddle/phi/kernels/gpudnn/softmax_gpudnn.h | 236 ++++++++++++++++++++-
 1 file changed, 235 insertions(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
index a15244cc52..75dfc8514a 100644
--- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
+++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/primitive/kernel_primitives.h"
 
@@ -26,6 +27,32 @@ limitations under the License. */
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 
+#define MATRIX_SOFTMAX_ALIGN_BYTES 16
+#define MATRIX_SOFTMAX_THREAHOLD 100000
+
+#define FIXED_BLOCK_DIM_BASE(dim, ...) \
+  case (dim): {                        \
+    constexpr auto kBlockDim = (dim);  \
+    __VA_ARGS__;                       \
+  } break
+
+#define FIXED_VEC_SIZE_BASE(vec_size, ...) \
+  case (vec_size): {                       \
+    constexpr auto VecSize = (vec_size);   \
+    __VA_ARGS__;                           \
+  } break
+
+#define FIXED_BLOCK_DIM(...)                \
+  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
+
+#define FIXED_VEC_SIZE(...)              \
+  FIXED_VEC_SIZE_BASE(8, ##__VA_ARGS__); \
+  FIXED_VEC_SIZE_BASE(4, ##__VA_ARGS__)
+
 namespace phi {
 
 using ScopedTensorDescriptor = paddle::platform::ScopedTensorDescriptor;
@@ -85,6 +112,20 @@ static inline int Log2Ceil(int value) {
   return log2_value;
 }
 
+inline int getBlockSize(int vec_size, uint64_t dim_size) {
+  uint64_t block_size = 1;
+  uint64_t max_block_size =
+      std::min(dim_size / vec_size, static_cast<uint64_t>(1024));
+
+  if (vec_size > 1) {
+    max_block_size /= 2;
+  }
+
+  while (block_size < (max_block_size)) block_size *= 2;
+  block_size = std::max(block_size, static_cast<uint64_t>(32));
+  return block_size;
+}
+
 template <typename T, int BatchSize, int WarpSize>
 __device__ __forceinline__ void WarpReduceSum(T* sum) {
 #pragma unroll
@@ -111,6 +152,41 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) {
   }
 }
 
+template <typename T>
+__inline__ __device__ void BlockReduceMax(T* val) {
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  WarpReduceMax<T, 1, 32>(val);
+
+  if (lane == 0) shared[wid] = *val;
+
+  __syncthreads();
+
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  *val = (lane < block_span) ? shared[lane] : -1e10f;
+  WarpReduceMax<T, 1, 32>(val);
+}
+
+template <typename T>
+__inline__ __device__ void BlockReduceSum(T* val) {
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  WarpReduceSum<T, 1, 32>(val);
+
+  __syncthreads();
+  if (lane == 0) shared[wid] = *val;
+
+  __syncthreads();
+
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  *val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
+  WarpReduceSum<T, 1, 32>(val);
+}
+
 template <typename Tx, typename Ty = Tx>
 struct ReduceMaxFunctor {
   inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }
@@ -120,6 +196,14 @@ struct ReduceMaxFunctor {
   }
 };
 
+template <typename T, typename AccT>
+struct MaxFunctor {
+  __device__ __forceinline__ AccT operator()(const AccT& max_v,
+                                             const T& v) const {
+    return max(max_v, static_cast<AccT>(v));
+  }
+};
+
 template <typename Tx, typename Ty = Tx>
 struct ExpFunctor {
   HOSTDEVICE inline Ty operator()(const Tx& x) const {
@@ -245,6 +329,126 @@ struct LogSoftmaxBackwardFunctor {
   Tx sum;
 };
 
+template <typename T, typename AccT>
+struct SumExpFunctor {
+  HOSTDEVICE inline SumExpFunctor(AccT v) : max_v(v) {}
+
+  HOSTDEVICE inline AccT operator()(AccT sum, T v) const {
+    return sum + std::exp(static_cast<AccT>(v) - max_v);
+  }
+
+ private:
+  AccT max_v;
+};
+
+template