Unverified commit face8f1f, authored by RichardWooSJTU and committed by GitHub

fix multihead_matmul nan error when seq len gt 1024 (#46286)

Parent: 23e06680
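Every hunk below makes the same one-line change in the four "ForLarge" softmax kernels: the strided loops used to run `for (int i = 0; i < seq_len; i += blockDim.x)` while indexing `qk_buf[threadIdx.x + i + qk_offset]`, so whenever seq_len is not an exact multiple of blockDim.x, threads on the final iteration read past the end of the row. The garbage they pick up flows through `__expf` and the final division and can surface as NaN in the attention output. The fix bounds the effective index `threadIdx.x + i` instead of the bare loop counter `i`.

A minimal, self-contained sketch of the pattern (the names `row_max_kernel` and `smem` are hypothetical, and a plain shared-memory tree reduction stands in for `phi::funcs::blockReduceMax`; this is not the PaddlePaddle kernel itself):

#include <cstdio>
#include <cuda_runtime.h>

// One block scans one row of length seq_len with blockDim.x threads striding
// across it; seq_len is deliberately not a multiple of blockDim.x.
__global__ void row_max_kernel(const float *data, float *out, int seq_len) {
  extern __shared__ float smem[];
  float local_max = -1e20f;
  // Buggy bound: `i < seq_len` lets data[threadIdx.x + i] run past the row
  // on the last iteration when seq_len % blockDim.x != 0.
  // The fixed bound limits the effective index, not the loop counter:
  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
    local_max = fmaxf(local_max, data[threadIdx.x + i]);
  }
  // Shared-memory tree reduction (stand-in for phi::funcs::blockReduceMax).
  smem[threadIdx.x] = local_max;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) {
      smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    *out = smem[0];
  }
}

int main() {
  const int seq_len = 1100;  // not a multiple of the block size below
  const int threads = 1024;
  float *h_data = new float[seq_len];
  for (int i = 0; i < seq_len; ++i) h_data[i] = static_cast<float>(i % 7);
  float *d_data = nullptr, *d_out = nullptr;
  cudaMalloc(&d_data, seq_len * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_data, h_data, seq_len * sizeof(float), cudaMemcpyHostToDevice);
  row_max_kernel<<<1, threads, threads * sizeof(float)>>>(d_data, d_out, seq_len);
  float result = 0.f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("row max = %f\n", result);  // expected: 6.000000
  delete[] h_data;
  cudaFree(d_data);
  cudaFree(d_out);
  return 0;
}

Compiled with nvcc, the demo prints `row max = 6.000000` with the guarded loop; with the buggy bound, the last iteration may read past the 1100-element buffer, the kind of out-of-bounds read that showed up as NaN in the real kernels.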
@@ -378,7 +378,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
   assert(blockDim.x % 32 == 0);
   T stride_max = -1e20f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     stride_max = qk_buf[threadIdx.x + i + qk_offset] +
                          bias_qk[threadIdx.x + i + qk_offset] >
                      stride_max
@@ -389,13 +389,13 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
   T max_val = phi::funcs::blockReduceMax<T>(stride_max, mask);

   T stride_sum = 0.f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     stride_sum += __expf(qk_buf[threadIdx.x + i + qk_offset] +
                          bias_qk[threadIdx.x + i + qk_offset] - max_val);
   }
   T sum_val = phi::funcs::blockReduceSum<T>(stride_sum, mask);

-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     qk_buf[threadIdx.x + i + qk_offset] =
         (T)(__expf(qk_buf[threadIdx.x + i + qk_offset] +
                    bias_qk[threadIdx.x + i + qk_offset] - max_val) /
@@ -417,7 +417,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
   assert(blockDim.x % 32 == 0);
   float stride_max = -1e20f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float tmp = static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
                                    bias_qk[threadIdx.x + i + qk_offset]);
     stride_max = tmp > stride_max ? tmp : stride_max;
@@ -425,14 +425,14 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
   float max_val = phi::funcs::blockReduceMax<float>(stride_max, mask);

   float stride_sum = 0.f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float tmp = static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
                                    bias_qk[threadIdx.x + i + qk_offset]);
     stride_sum += __expf(tmp - max_val);
   }
   float sum_val = phi::funcs::blockReduceSum<float>(stride_sum, mask);

-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float tmp =
         __expf(static_cast<float>(qk_buf[threadIdx.x + i + qk_offset] +
                                   bias_qk[threadIdx.x + i + qk_offset]) -
@@ -454,7 +454,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
   assert(blockDim.x % 32 == 0);
   float2 stride_max = make_float2(-1e20f, -1e20f);
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
                                          bias_qk_[threadIdx.x + i + qk_offset]);
     stride_max.x = max(stride_max.x, cur.x);
@@ -464,7 +464,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
   float max_val =
       phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
   float2 stride_sum = make_float2(0.f, 0.f);
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
                                          bias_qk_[threadIdx.x + i + qk_offset]);
     stride_sum.x += __expf(cur.x - max_val);
@@ -475,7 +475,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
   float sum_val =
       phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
       1e-6f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur = phi::funcs::ToFloat2<T>(qk_buf_[threadIdx.x + i + qk_offset] +
                                          bias_qk_[threadIdx.x + i + qk_offset]);
     qk_buf_[threadIdx.x + i + qk_offset] = phi::funcs::FloatsToPair<T>(
@@ -499,7 +499,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
   assert(blockDim.x % 32 == 0);
   float2 stride_max = make_float2(-1e20f, -1e20f);
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur =
         phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
                                     bias_qk_[threadIdx.x + i + qk_offset]);
@@ -510,7 +510,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
   float max_val =
       phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
   float2 stride_sum = make_float2(0.f, 0.f);
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur =
         phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
                                     bias_qk_[threadIdx.x + i + qk_offset]);
@@ -522,7 +522,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
   float sum_val =
       phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
       1e-6f;
-  for (int i = 0; i < seq_len; i += blockDim.x) {
+  for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float2 cur =
         phi::funcs::ToFloat2<half2>(qk_buf_[threadIdx.x + i + qk_offset] +
                                     bias_qk_[threadIdx.x + i + qk_offset]);
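The guard alone is sufficient because the block-wide reductions sit outside the loops: every thread still reaches `blockReduceMax`/`blockReduceSum`, and a thread whose indices all land at or beyond seq_len contributes only the identity values the unchanged initializers already provide (-1e20f for stride_max, 0.f for stride_sum), so the softmax result over the valid elements is unaffected. The `1e-6f` added to sum_val in the Large2 kernels independently keeps the final division away from zero.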