diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu
index 5d47c066d6eac7bc7b6f1888cb688768cbdfda94..9274146290d5f3be7cf1a67a53267d2e82c82ee8 100644
--- a/paddle/fluid/operators/math/bert_encoder_functor.cu
+++ b/paddle/fluid/operators/math/bert_encoder_functor.cu
@@ -143,30 +143,42 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const T *bias_qk_,
   int qk_offset = blockIdx.x * seq_len;
   assert(blockDim.x % 32 == 0);
 
-  __shared__ float s_sum, s_max;
-
-  float qk = threadIdx.x < seq_len
-                 ? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
-                                       bias_qk_[threadIdx.x + qk_offset]))
-                 : 0.0f;
-  float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-
+  float tmp = threadIdx.x < seq_len
+                  ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
+                                       bias_qk_[threadIdx.x + qk_offset])
+                  : -1e20f;
   float max_val = blockReduceMax<float>(tmp, mask);
-  if (threadIdx.x == 0) s_max = max_val;
-  __syncthreads();
-
-  float qk_tmp =
-      threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
+
+  float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
   float sum_val = blockReduceSum<float>(qk_tmp, mask);
-  if (threadIdx.x == 0) {
-    s_sum = sum_val + 1e-6f;
-  }
-  __syncthreads();
 
   if (threadIdx.x < seq_len)
-    qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
+    qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
+}
+
+template <typename T>
+__global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, const T *bias_qk_,
+                                         const int batch_size,
+                                         const int head_num, const int seq_len,
+                                         const unsigned mask) {
+  int qk_offset = blockIdx.x * seq_len;
+  int idx = threadIdx.x;
+  assert(blockDim.x % 32 == 0);
+
+  float2 tmp =
+      idx < seq_len
+          ? ToFloat2<T>(qk_buf_[idx + qk_offset] + bias_qk_[idx + qk_offset])
+          : make_float2(-1e20f, -1e20f);
+  float max_val = blockReduceMax<float>(max(tmp.x, tmp.y), mask);
+  float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
+                                              __expf(tmp.y - max_val))
+                                : make_float2(0.f, 0.f);
+  float sum_val = blockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
+
+  if (idx < seq_len) {
+    qk_buf_[idx + qk_offset] =
+        FloatsToPair<T>(qk_tmp.x / sum_val, qk_tmp.y / sum_val);
+  }
 }
 
 template <typename T>
@@ -198,21 +210,28 @@ inline void MatMulWithHeadQK(const platform::CUDADeviceContext &context,
                         "seq_len should <= 1024, "
                         "but received seq_len is:%d",
                         seq_len));
-  if (seq_len <= 32)
-    block = 32;
-  else if (seq_len > 32 && seq_len <= 64)
-    block = 64;
-  else if (seq_len > 64 && seq_len <= 128)
-    block = 128;
-  else if (seq_len > 128 && seq_len <= 256)
-    block = 256;
-  else if (seq_len > 256 && seq_len <= 512)
-    block = 512;
-  else
-    block = 1024;
-
-  SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
+  if (seq_len % 2 == 0) {
+    block = (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32;
+#ifdef SUPPORTS_CUDA_FP16
+    if (std::is_same<T, float>::value) {
+#endif
+      SoftmaxKernelWithEltadd2<float2><<<grid, block, 0, stream>>>(
+          reinterpret_cast<float2 *>(qk_buf_),
+          reinterpret_cast<const float2 *>(bias_qk), batch_size, head_num,
+          seq_len / 2, FINAL_MASK);
+#ifdef SUPPORTS_CUDA_FP16
+    } else {
+      SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
+          reinterpret_cast<__half2 *>(qk_buf_),
+          reinterpret_cast<const __half2 *>(bias_qk), batch_size, head_num,
+          seq_len / 2, FINAL_MASK);
+    }
+#endif
+  } else {
+    block = (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
+    SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
+        qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
+  }
 }
 
 template <typename T>
diff --git a/paddle/fluid/operators/math/math_cuda_utils.h b/paddle/fluid/operators/math/math_cuda_utils.h
index 17175fa7299d40938509b9243b5759278991e821..0325717b4d3714e8eae260beb89df7f2addda88f 100644
--- a/paddle/fluid/operators/math/math_cuda_utils.h
+++ b/paddle/fluid/operators/math/math_cuda_utils.h
@@ -26,9 +26,15 @@ __device__ __forceinline__ T FromFloat(float a);
 template <typename T>
 __device__ __forceinline__ float ToFloat(T a);
 
+template <typename T>
+__device__ __forceinline__ float2 ToFloat2(T a);
+
 template <typename T>
 __device__ __forceinline__ T exp_func(T a);
 
+template <typename T>
+__device__ __forceinline__ T FloatsToPair(const float a, const float b);
+
 template <typename T>
 struct KeyValuePair;
 
@@ -54,11 +60,35 @@ __device__ __forceinline__ float ToFloat<float>(float a) {
   return a;
 }
 
+template <>
+__device__ __forceinline__ float2 ToFloat2<float2>(float2 a) {
+  return a;
+}
+
+template <>
+__device__ __forceinline__ float2 FloatsToPair(const float a, const float b) {
+  return make_float2(a, b);
+}
+
+__inline__ __device__ float2 operator+(const float2 &a, const float2 &b) {
+  return make_float2(a.x + b.x, a.y + b.y);
+}
+
 #ifdef SUPPORTS_CUDA_FP16
 template <>
 __device__ __forceinline__ float ToFloat<half>(half a) {
   return __half2float(a);
 }
+
+template <>
+__device__ __forceinline__ float2 ToFloat2<__half2>(__half2 a) {
+  return __half22float2(a);
+}
+
+template <>
+__device__ __forceinline__ __half2 FloatsToPair(const float a, const float b) {
+  return __floats2half2_rn(a, b);
+}
 #endif
 
 template <>
@@ -148,7 +178,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
 
   // align block_span to warpSize
   int block_span = (blockDim.x + warpSize - 1) >> 5;
-  val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
+  val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
   val = warpReduceSum<T>(val, mask);
 
   return val;
@@ -180,7 +210,7 @@ __inline__ __device__ T blockReduceMax(T val, unsigned mask) {
 
   // align block_span to warpSize
   int block_span = (blockDim.x + warpSize - 1) >> 5;
-  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = (lane < block_span) ? shared[lane] : -1e10f;
   val = warpReduceMax<T>(val, mask);
 
   return val;
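Note on the launch configuration in MatMulWithHeadQK: when seq_len is even, each thread of SoftmaxKernelWithEltadd2 handles one float2/__half2 pair, so the kernel covers seq_len / 2 columns per row and the block size only needs to span half the row, rounded up to a multiple of 32 because the block reductions assert blockDim.x % 32 == 0. The host-side sketch below only mirrors that block-size arithmetic as a sanity check; the helper name PickSoftmaxBlock is hypothetical and not part of the patch.

#include <cassert>

// Hypothetical mirror of the block-size selection in the patch:
// - even seq_len: one float2/__half2 pair per thread (seq_len / 2 columns),
//   rounded up to a warp multiple.
// - odd seq_len: one element per thread, rounded up to a warp multiple,
//   since the warp-shuffle reductions require blockDim.x % 32 == 0.
inline int PickSoftmaxBlock(int seq_len) {
  if (seq_len % 2 == 0) {
    return (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32;
  }
  return (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
}

int main() {
  assert(PickSoftmaxBlock(64) == 32);     // 32 threads cover 32 pairs
  assert(PickSoftmaxBlock(128) == 64);    // 64 threads cover 64 pairs
  assert(PickSoftmaxBlock(33) == 64);     // odd: 33 threads rounded up to 64
  assert(PickSoftmaxBlock(1024) == 512);  // 512 threads cover 512 pairs
  return 0;
}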