提交 04511cf9(未验证) 作者: xiaoxiaohehe001 提交者: GitHub

fix_moe (#49353)

上级 a221158f
@@ -318,6 +318,7 @@ __global__ void softmax_kernel_v4(
 const int seq_len_1,
 const int seq_len_2,
 const T scalar) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
 for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
 float data[ITEMS_PER_THREAD];
 int qk_offset;
@@ -370,6 +371,7 @@ __global__ void softmax_kernel_v4(
 qk_buf_[qk_offset] = (T)(data[i] * s_mean);
 }
 }
+#endif
 }
 template <typename T, int ITEMS_PER_THREAD>
@@ -380,6 +382,7 @@ __global__ void softmax_kernel_v4_half2(T* qk_buf_,
 const int seq_len_1,
 const int seq_len_2,
 const T scalar) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
 using T2 = half2;
 T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
 const T2* attr_mask_half2 = (const T2*)attr_mask;
@@ -447,6 +450,7 @@ __global__ void softmax_kernel_v4_half2(T* qk_buf_,
 qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean));
 }
 }
+#endif
 }
 template <typename T, int ITEMS_PER_THREAD, int NUM>
@@ -457,6 +461,7 @@ __global__ void softmax_kernel_v5_half2(T* qk_buf_,
 const int seq_len_1,
 const int seq_len_2,
 const T scalar) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
 using T2 = half2;
 T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
 const T2* attr_mask_half2 = (const T2*)attr_mask;
@@ -579,6 +584,7 @@ __global__ void softmax_kernel_v5_half2(T* qk_buf_,
 }
 }
 }
+#endif
 }
 // -------- transpose_kernel -------- //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册