未验证 提交 f4507974 编写于 作者: W WangXi 提交者: GitHub

Fix fused_multi_transformer compilation failure on CUDA architectures below sm_53 (#42315)

上级 2e1fb26b
......@@ -534,6 +534,8 @@ template <typename T, int Dh, int THREADS_PER_KEY, int THREADS_PER_VALUE,
int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
static_assert(Dh % THREADS_PER_KEY == 0, "");
static_assert(Dh % THREADS_PER_VALUE == 0, "");
......@@ -821,6 +823,9 @@ __global__ void masked_multihead_attention_kernel(
printf("\n");
}
#endif
#else
assert(false);
#endif
}
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册