未验证 提交 04511cf9 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

fix_moe (#49353)

上级 a221158f
......@@ -318,6 +318,7 @@ __global__ void softmax_kernel_v4(
const int seq_len_1,
const int seq_len_2,
const T scalar) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
float data[ITEMS_PER_THREAD];
int qk_offset;
......@@ -370,6 +371,7 @@ __global__ void softmax_kernel_v4(
qk_buf_[qk_offset] = (T)(data[i] * s_mean);
}
}
#endif
}
template <typename T, int ITEMS_PER_THREAD>
......@@ -380,6 +382,7 @@ __global__ void softmax_kernel_v4_half2(T* qk_buf_,
const int seq_len_1,
const int seq_len_2,
const T scalar) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
using T2 = half2;
T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
const T2* attr_mask_half2 = (const T2*)attr_mask;
......@@ -447,6 +450,7 @@ __global__ void softmax_kernel_v4_half2(T* qk_buf_,
qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean));
}
}
#endif
}
template <typename T, int ITEMS_PER_THREAD, int NUM>
......@@ -457,6 +461,7 @@ __global__ void softmax_kernel_v5_half2(T* qk_buf_,
const int seq_len_1,
const int seq_len_2,
const T scalar) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
using T2 = half2;
T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
const T2* attr_mask_half2 = (const T2*)attr_mask;
......@@ -579,6 +584,7 @@ __global__ void softmax_kernel_v5_half2(T* qk_buf_,
}
}
}
#endif
}
// -------- transpose_kernel -------- //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册