Unverified commit face8f1f authored by RichardWooSJTU, committed by GitHub

fix multihead_matmul nan error when seq_len exceeds 1024 (#46286)

Parent 23e06680
@@ -378,7 +378,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
   assert(blockDim.x % 32 == 0);
 
   T stride_max = -1e20f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     stride_max = qk_buf[threadIdx.x + i + qk_offset] +
                          bias_qk[threadIdx.x + i + qk_offset] >
                      stride_max
@@ -389,13 +389,13 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
   T max_val = phi::funcs::blockReduceMax<T>(stride_max, mask);
 
   T stride_sum = 0.f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     stride_sum += __expf(qk_buf[threadIdx.x + i + qk_offset] +
                          bias_qk[threadIdx.x + i + qk_offset] - max_val);
   }
 
   T sum_val = phi::funcs::blockReduceSum<T>(stride_sum, mask);
 
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     qk_buf[threadIdx.x + i + qk_offset] =
         (T)(__expf(qk_buf[threadIdx.x + i + qk_offset] +
                    bias_qk[threadIdx.x + i + qk_offset] - max_val) /
@@ -417,7 +417,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
   assert(blockDim.x % 32 == 0);
 
   float stride_max = -1e20f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float tmp = static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
                                    bias_qk[threadIdx.x + i + qk_offset]);
     stride_max = tmp > stride_max ? tmp : stride_max;
@@ -425,14 +425,14 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
   float max_val = phi::funcs::blockReduceMax<float>(stride_max, mask);
 
   float stride_sum = 0.f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float tmp = static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
                                    bias_qk[threadIdx.x + i + qk_offset]);
     stride_sum += __expf(tmp - max_val);
   }
 
   float sum_val = phi::funcs::blockReduceSum<float>(stride_sum, mask);
 
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float tmp =
         __expf(static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
                                   bias_qk[threadIdx.x + i + qk_offset]) -
@@ -454,7 +454,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
   assert(blockDim.x % 32 == 0);
 
   float2 stride_max = make_float2(-1e20f, -1e20f);
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
                                          bias_qk_[threadIdx.x + i + qk_offset]);
     stride_max.x = max(stride_max.x, cur.x);
@@ -464,7 +464,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
       phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
 
   float2 stride_sum = make_float2(0.f, 0.f);
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
                                          bias_qk_[threadIdx.x + i + qk_offset]);
     stride_sum.x += __expf(cur.x - max_val);
@@ -475,7 +475,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
       phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
       1e-6f;
 
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
                                          bias_qk_[threadIdx.x + i + qk_offset]);
     qk_buf_[threadIdx.x + i + qk_offset] = phi::funcs::FloatsToPair<T>(
@@ -499,7 +499,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
   assert(blockDim.x % 32 == 0);
 
   float2 stride_max = make_float2(-1e20f, -1e20f);
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur =
         phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
                                     bias_qk_[threadIdx.x + i + qk_offset]);
@@ -510,7 +510,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
       phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
 
   float2 stride_sum = make_float2(0.f, 0.f);
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur =
         phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
                                     bias_qk_[threadIdx.x + i + qk_offset]);
@@ -522,7 +522,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
       phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
       1e-6f;
 
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur =
         phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
                                     bias_qk_[threadIdx.x + i + qk_offset]);
...
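The change is identical in all four kernel variants: the strided loops used to run `for (int i = 0; i < seq_len; i += blockDim.x)`, so once `seq_len` exceeds the 1024-thread block size and is not an exact multiple of `blockDim.x`, the threads in the last iteration index past the end of the row. Those out-of-bounds reads feed `blockReduceMax`/`blockReduceSum` and can surface as NaN in the softmax output; the added `threadIdx.x + i < seq_len` guard keeps every access in bounds. Below is a minimal, self-contained CUDA sketch of the guarded strided-loop pattern, not the Paddle kernel itself: the name `guarded_row_max` and the shared-memory reduction are illustrative stand-ins for `phi::funcs::blockReduceMax`.

```cuda
// A minimal sketch of the guarded strided-loop pattern (assumed names;
// not PaddlePaddle code). Build: nvcc -o guarded_max guarded_max.cu
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// One block scans one row of `seq_len` scores with stride blockDim.x.
// Without the `threadIdx.x + i < seq_len` bound, the last iteration reads
// past the row whenever seq_len is not a multiple of blockDim.x, and the
// garbage values poison the block-wide reduction.
__global__ void guarded_row_max(const float *scores, float *row_max,
                                int seq_len) {
  const int row_offset = blockIdx.x * seq_len;
  float thread_max = -1e20f;  // same sentinel the Paddle kernels use
  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
    float v = scores[row_offset + threadIdx.x + i];
    thread_max = v > thread_max ? v : thread_max;
  }
  // Plain shared-memory reduction as a stand-in for blockReduceMax;
  // assumes blockDim.x is a power of two (1024 below).
  __shared__ float smem[1024];
  smem[threadIdx.x] = thread_max;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) {
      smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) row_max[blockIdx.x] = smem[0];
}

int main() {
  const int rows = 2;
  const int seq_len = 1500;  // > 1024 and not a multiple of blockDim.x
  const int n = rows * seq_len;

  float *h_scores = (float *)malloc(n * sizeof(float));
  for (int i = 0; i < n; ++i) h_scores[i] = (i % 97) * 0.1f;

  float *d_scores, *d_max;
  cudaMalloc(&d_scores, n * sizeof(float));
  cudaMalloc(&d_max, rows * sizeof(float));
  cudaMemcpy(d_scores, h_scores, n * sizeof(float), cudaMemcpyHostToDevice);

  guarded_row_max<<<rows, 1024>>>(d_scores, d_max, seq_len);

  float h_max[2];
  cudaMemcpy(h_max, d_max, rows * sizeof(float), cudaMemcpyDeviceToHost);
  for (int r = 0; r < rows; ++r) printf("row %d max = %f\n", r, h_max[r]);

  cudaFree(d_scores);
  cudaFree(d_max);
  free(h_scores);
  return 0;
}
```

Running this with a row length of 1500 against a 1024-thread block exercises exactly the case the guard protects: the second loop iteration is valid only for `threadIdx.x < 476`, and the guard makes the remaining threads skip it instead of reading past the row.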