diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
index c36ee69723e45261f30efd2fe6bae3e719bbdd07..3c3a59b219615c17c412389e47f3b4ec19af599b 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -95,7 +95,7 @@ using float16 = plat::float16;
 #define MMHA_USE_FP32_ACUM_FOR_LOGITS
 #define MMHA_USE_FP32_ACUM_FOR_OUT
 #define MMHA_USE_FP32_ACUM_FOR_FMA
-#define MMHA_USE_HMMA_FOR_REDUCTION
+// #define MMHA_USE_HMMA_FOR_REDUCTION
 
 template <typename T>
 class PDDataTypeTraits;
@@ -601,7 +601,8 @@ template <int N>
 inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N],
                                      const uint32_t (&k)[N],
                                      float inv_sqrt_dh) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
+    __CUDA_ARCH__ >= 750
 #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
   using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
 #else
@@ -641,7 +642,8 @@ struct Qk_dot<float16, 4> {
   static inline __device__ float dot(const uint32_t (&q)[N],
                                      const uint32_t (&k)[N],
                                      float inv_sqrt_dh) {
-#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
+#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
+    __CUDA_ARCH__ >= 750
     return qk_hmma_dot_(q, k, inv_sqrt_dh);
 #else
     return qk_dot_<4>(q, k, inv_sqrt_dh);
@@ -1104,7 +1106,8 @@ void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
   if (params.timestep < 32) {
     MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
   } else if (params.timestep < 2048) {
-#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
+#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && defined(__CUDA_ARCH__) && \
+    __CUDA_ARCH__ >= 750
     MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream);
 #else
     MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);