未验证 提交 08cada98 编写于 作者: R RichardWooSJTU 提交者: GitHub

fix build error in low arch (#44391)

上级 dd0a07f2
......@@ -38,10 +38,12 @@ __global__ void ElementwiseMask(const T* a,
const T* b,
T* res,
int num_elements) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_elements) return;
const T zero = 0;
res[tid] = b[tid] >= zero ? a[tid] : zero;
#endif
}
template <typename T>
......@@ -121,6 +123,7 @@ __global__ void ReduceSum2(
template <>
__global__ void ReduceSum2<half>(
const half* src, half* dst, int bsz, int nb_head, int max_seq_len) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
......@@ -152,6 +155,7 @@ __global__ void ReduceSum2<half>(
static_cast<size_t>(bsz * max_seq_len),
static_cast<platform::float16>(res_half[0]));
}
#endif
}
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册