Unverified commit 479c8834 authored by zlsh80826, committed by GitHub

[Paddle-TRT] Fixes #24731, opt for SoftmaxKernelWithEltadd kernel, test=develop (#24834)

* blockReduce opt

* launch threads align to warpSize

* reduce unnecessary shared memory for broadcast reduced value

* vectorize SoftmaxKernelWithEltadd

* add fp16 constraint

* test=develop
Parent 2c500c30
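The heart of the change is that blockReduceMax / blockReduceSum now return the reduced value to every thread in the block (a warp-level reduction followed by a second warp pass over the per-warp partials), so SoftmaxKernelWithEltadd can drop the `__shared__ float s_sum, s_max` broadcast variables and two `__syncthreads()` barriers. A minimal sketch of that broadcast-style block reduction, assuming `blockDim.x` is a multiple of warpSize and at most 1024 (helper names are illustrative, not the exact Paddle code):

```cuda
// Sketch only: a block-wide max reduction that leaves the result in every thread.
__device__ __forceinline__ float WarpReduceMaxSketch(float val, unsigned mask) {
  // __shfl_xor_sync leaves the same reduced value in every lane of the warp.
  for (int offset = 16; offset > 0; offset >>= 1)
    val = fmaxf(val, __shfl_xor_sync(mask, val, offset));
  return val;
}

__device__ __forceinline__ float BlockReduceMaxSketch(float val, unsigned mask) {
  static __shared__ float shared[32];        // one partial result per warp
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = WarpReduceMaxSketch(val, mask);      // reduce within each warp
  if (lane == 0) shared[wid] = val;          // warp leaders publish partials
  __syncthreads();

  int block_span = (blockDim.x + warpSize - 1) >> 5;  // number of active warps
  // Every warp reloads the partials and reduces them again, so afterwards all
  // threads hold the block-wide max and no extra shared-memory broadcast or
  // __syncthreads() is needed.
  val = (lane < block_span) ? shared[lane] : -1e10f;
  return WarpReduceMaxSketch(val, mask);
}
```

The last two hunks of this diff make exactly this `lane < block_span` change in the existing helpers; the kernel hunks below are what it enables.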
@@ -143,30 +143,42 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const T *bias_qk_,
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % 32 == 0);
- __shared__ float s_sum, s_max;
- float qk = threadIdx.x < seq_len
- ? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
- bias_qk_[threadIdx.x + qk_offset]))
- : 0.0f;
- float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
+ float tmp = threadIdx.x < seq_len
+ ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
+ bias_qk_[threadIdx.x + qk_offset])
+ : -1e20f;
float max_val = blockReduceMax<float>(tmp, mask);
- if (threadIdx.x == 0) s_max = max_val;
- __syncthreads();
- float qk_tmp =
- threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
+ float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
float sum_val = blockReduceSum<float>(qk_tmp, mask);
- if (threadIdx.x == 0) {
- s_sum = sum_val + 1e-6f;
- }
- __syncthreads();
if (threadIdx.x < seq_len)
- qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
+ qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
}
+ template <typename T>
+ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, const T *bias_qk_,
+ const int batch_size,
+ const int head_num, const int seq_len,
+ const unsigned mask) {
+ int qk_offset = blockIdx.x * seq_len;
+ int idx = threadIdx.x;
+ assert(blockDim.x % 32 == 0);
+ float2 tmp =
+ idx < seq_len
+ ? ToFloat2<T>(qk_buf_[idx + qk_offset] + bias_qk_[idx + qk_offset])
+ : make_float2(-1e20f, -1e20f);
+ float max_val = blockReduceMax<float>(max(tmp.x, tmp.y), mask);
+ float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
+ __expf(tmp.y - max_val))
+ : make_float2(0.f, 0.f);
+ float sum_val = blockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
+ if (idx < seq_len) {
+ qk_buf_[idx + qk_offset] =
+ FloatsToPair<T>(qk_tmp.x / sum_val, qk_tmp.y / sum_val);
+ }
+ }
template <typename T>
@@ -198,21 +210,28 @@ inline void MatMulWithHeadQK(const platform::CUDADeviceContext &context,
"seq_len should <= 1024, "
"but received seq_len is:%d",
seq_len));
- if (seq_len <= 32)
- block = 32;
- else if (seq_len > 32 && seq_len <= 64)
- block = 64;
- else if (seq_len > 64 && seq_len <= 128)
- block = 128;
- else if (seq_len > 128 && seq_len <= 256)
- block = 256;
- else if (seq_len > 256 && seq_len <= 512)
- block = 512;
- else
- block = 1024;
+ if (seq_len % 2 == 0) {
+ block = (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32;
+ #ifdef SUPPORTS_CUDA_FP16
+ if (std::is_same<T, float>::value) {
+ #endif
+ SoftmaxKernelWithEltadd2<float2><<<grid, block, 0, stream>>>(
+ reinterpret_cast<float2 *>(qk_buf_),
+ reinterpret_cast<const float2 *>(bias_qk), batch_size, head_num,
+ seq_len / 2, FINAL_MASK);
+ #ifdef SUPPORTS_CUDA_FP16
+ } else {
+ SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
+ reinterpret_cast<__half2 *>(qk_buf_),
+ reinterpret_cast<const __half2 *>(bias_qk), batch_size, head_num,
+ seq_len / 2, FINAL_MASK);
+ }
+ #endif
+ } else {
+ block = (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
+ SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
+ qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
+ }
}
template <typename T>
......
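When `seq_len` is even, MatMulWithHeadQK above launches the vectorized kernel with one thread per float2/__half2 pair, so the block size is ceil(seq_len / 2) rounded up to the next multiple of warpSize; when it is odd, the scalar kernel gets ceil(seq_len / 32) * 32 threads ("launch threads align to warpSize"). A small sketch of that arithmetic, with a helper name that is assumed rather than taken from the patch:

```cuda
// Sketch of the launch-size math used in MatMulWithHeadQK above (assumed helper).
// Even path: each thread handles a float2/__half2 pair, so we need
// ceil(seq_len / 2) threads, rounded up to a multiple of warpSize (32).
inline int SoftmaxBlockSizeSketch(int seq_len) {
  if (seq_len % 2 == 0)
    return (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32;
  return (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
}
// e.g. seq_len = 384 -> 192 threads (exactly 384 / 2),
//      seq_len = 130 ->  96 threads (ceil(130 / 2) = 65, rounded up to 96),
//      seq_len = 129 -> 160 threads for the scalar kernel.
```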
@@ -26,9 +26,15 @@ __device__ __forceinline__ T FromFloat(float a);
template <typename T>
__device__ __forceinline__ float ToFloat(T a);
+ template <typename T>
+ __device__ __forceinline__ float2 ToFloat2(T a);
template <typename T>
__device__ __forceinline__ T exp_func(T a);
+ template <typename T>
+ __device__ __forceinline__ T FloatsToPair(const float a, const float b);
template <typename T>
struct KeyValuePair;
@@ -54,11 +60,35 @@ __device__ __forceinline__ float ToFloat<float>(float a) {
return a;
}
+ template <>
+ __device__ __forceinline__ float2 ToFloat2<float2>(float2 a) {
+ return a;
+ }
+ template <>
+ __device__ __forceinline__ float2 FloatsToPair(const float a, const float b) {
+ return make_float2(a, b);
+ }
+ __inline__ __device__ float2 operator+(const float2 &a, const float2 &b) {
+ return make_float2(a.x + b.x, a.y + b.y);
+ }
#ifdef SUPPORTS_CUDA_FP16
template <>
__device__ __forceinline__ float ToFloat<half>(half a) {
return __half2float(a);
}
+ template <>
+ __device__ __forceinline__ float2 ToFloat2<__half2>(__half2 a) {
+ return __half22float2(a);
+ }
+ template <>
+ __device__ __forceinline__ __half2 FloatsToPair(const float a, const float b) {
+ return __floats2half2_rn(a, b);
+ }
#endif
template <>
@@ -148,7 +178,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
- val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
+ val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
val = warpReduceSum<T>(val, mask);
return val;
@@ -180,7 +210,7 @@ __inline__ __device__ T blockReduceMax(T val, unsigned mask) {
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
- val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+ val = (lane < block_span) ? shared[lane] : -1e10f;
val = warpReduceMax(val, mask);
return val;
......
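Taken together, both softmax kernels compute a numerically stable softmax over `qk + bias` per attention row: idle lanes contribute `-1e20f` to the max and `0.0f` to the sum, so padding threads never change the result. A host-side reference of that row computation, for comparison only (the function name and the use of std::vector are assumptions):

```cuda
// Host-side reference (sketch) of the per-row computation both kernels perform:
//   out[i] = exp((qk[i] + bias[i]) - max_j(qk[j] + bias[j])) / sum_j exp(...)
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> SoftmaxWithBiasReference(const std::vector<float> &qk,
                                            const std::vector<float> &bias) {
  std::vector<float> out(qk.size());
  float max_val = -1e20f;  // same padding value the kernels use for idle lanes
  for (size_t i = 0; i < qk.size(); ++i)
    max_val = std::max(max_val, qk[i] + bias[i]);
  float sum_val = 0.f;
  for (size_t i = 0; i < qk.size(); ++i) {
    out[i] = std::exp((qk[i] + bias[i]) - max_val);
    sum_val += out[i];
  }
  // The patched scalar kernel divides by sum_val directly; the vectorized
  // kernel keeps a 1e-6f epsilon in its denominator.
  for (size_t i = 0; i < qk.size(); ++i) out[i] /= sum_val;
  return out;
}
```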