Unverified commit bf4d1792, authored by lzy, committed by GitHub

fix mma_tensorcore (#48386)

* fix mma_tensorcore (__CUDA_ARCH__)

* disable tensorcore by default.

Disable tensorcore by default: the check on __CUDA_ARCH__ can cause undefined behavior in some environments. It can be enabled manually on a machine that supports tensorcore.
Parent 2005d45a
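For readers unfamiliar with the pitfall: __CUDA_ARCH__ is defined only during nvcc's device compilation passes, never in host code, so an #if on it inside a host-side launch helper silently takes the fallback branch, and keying host and device code on the same macro can leave the two compilation passes seeing different programs. A minimal sketch of the problem, with hypothetical kernel names that are not from the Paddle source:

// Illustrative sketch of the __CUDA_ARCH__ pitfall; kernel names are hypothetical.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void hmma_path_kernel() { printf("tensor-core path\n"); }
__global__ void fallback_kernel() { printf("fallback path\n"); }

void launch(cudaStream_t stream) {
  // Host code: __CUDA_ARCH__ is never defined here, so in an #if it
  // evaluates as 0 and the #else branch is always taken, regardless of
  // which GPU the binary eventually runs on. If device-side code keys
  // other #ifs on the same macro, host and device compilation can get
  // out of sync -- the kind of undefined behavior the commit message
  // refers to.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  hmma_path_kernel<<<1, 1, 0, stream>>>();
#else
  fallback_kernel<<<1, 1, 0, stream>>>();
#endif
}

int main() {
  launch(0);
  cudaDeviceSynchronize();
  return 0;
}

The commit therefore gates every such #if behind defined(MMHA_USE_HMMA_FOR_REDUCTION) and leaves that macro off by default.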
...
@@ -95,7 +95,7 @@ using float16 = plat::float16;
 #define MMHA_USE_FP32_ACUM_FOR_LOGITS
 #define MMHA_USE_FP32_ACUM_FOR_OUT
 #define MMHA_USE_FP32_ACUM_FOR_FMA
-#define MMHA_USE_HMMA_FOR_REDUCTION
+// #define MMHA_USE_HMMA_FOR_REDUCTION

 template <typename D>
 class PDDataTypeTraits;
...
@@ -601,7 +601,8 @@ template <int N>
 inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N],
                                      const uint32_t (&k)[N],
                                      float inv_sqrt_dh) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
+    __CUDA_ARCH__ >= 750
 #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
   using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
 #else
...
@@ -641,7 +642,8 @@ struct Qk_dot<float16, 4> {
   static inline __device__ float dot(const uint32_t (&q)[N],
                                      const uint32_t (&k)[N],
                                      float inv_sqrt_dh) {
-#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
+#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
+    __CUDA_ARCH__ >= 750
     return qk_hmma_dot_(q, k, inv_sqrt_dh);
 #else
     return qk_dot_<4>(q, k, inv_sqrt_dh);
...
@@ -1104,7 +1106,8 @@ void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
   if (params.timestep < 32) {
     MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
   } else if (params.timestep < 2048) {
-#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
+#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
+    __CUDA_ARCH__ >= 750
     MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream);
 #else
     MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
...
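Note on re-enabling: with the in-file define commented out, all three guards reduce to defined(MMHA_USE_HMMA_FOR_REDUCTION), so on a machine whose GPU has tensor cores (compute capability 7.5 or newer) the HMMA path can be restored without further changes to the guards. Two options, sketched under the assumption that your build lets you either edit the header or pass extra flags to nvcc:

// Option 1: restore the switch this commit commented out (near line 95):
#define MMHA_USE_HMMA_FOR_REDUCTION

// Option 2: leave the source as-is and define the macro at build time,
// e.g. with a hypothetical nvcc invocation (adjust to your build system):
//   nvcc -DMMHA_USE_HMMA_FOR_REDUCTION ...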