From d79eda716bfdb31bc1dc41ecf6aefb5330e50ef3 Mon Sep 17 00:00:00 2001 From: lzy <569782149@qq.com> Date: Mon, 21 Nov 2022 16:25:02 +0800 Subject: [PATCH] mma qk tensor_core (#48087) * use mma for QK dot computing in fused_multi_transformer. * Update fused_multi_transformer_op.cu.h --- .../fused/fused_multi_transformer_op.cu.h | 119 +++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index e6f4461f0c..777ee83c38 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -13,6 +13,8 @@ limitations under the License. */ // https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu // We add License in the head. +#pragma once + #include #include @@ -88,6 +90,23 @@ using float16 = plat::float16; #define MMHA_USE_FP32_ACUM_FOR_LOGITS #define MMHA_USE_FP32_ACUM_FOR_OUT +#define MMHA_USE_FP32_ACUM_FOR_FMA +#define MMHA_USE_HMMA_FOR_REDUCTION + +template +class PDDataTypeTraits; + +template <> +class PDDataTypeTraits { + public: + typedef float DataType; +}; + +template <> +class PDDataTypeTraits { + public: + typedef half DataType; +}; template struct Masked_multihead_attention_params { @@ -150,6 +169,17 @@ template <> struct V_vec_ { using Type = uint32_t; }; template <> struct V_vec_ { using Type = uint2; }; template <> struct V_vec_ { using Type = uint4; }; +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +#endif + #ifdef MMHA_USE_FP32_ACUM_FOR_OUT template struct V_vec_acum_fp32_ {}; // template <> struct V_vec_acum_fp32_ { using Type = float; }; @@ -318,6 +348,15 @@ inline __device__ uint32_t mul(uint32_t a, float b) { return res; } +template <> +inline __device__ float2 mul(uint32_t a, float b) { + float2 tmp = half2_to_float2(a); + float2 res; + res.x = tmp.x * b; + res.y = tmp.y * b; + return res; +} + template <> inline __device__ uint2 mul(uint2 a, float b) { uint2 res; @@ -344,6 +383,15 @@ inline __device__ float2 mul(float2 a, float b) { return res; } +template <> +inline __device__ float2 mul(float2 a, uint32_t b) { + float2 tmp_b = half2_to_float2(b); + float2 res; + res.x = a.x * tmp_b.x; + res.y = a.y * tmp_b.y; + return res; +} + template <> inline __device__ float4 mul(float4 a, float b) { float4 res; @@ -403,6 +451,12 @@ inline __device__ float2 fma(float2 a, float2 b, float2 c) { return d; } +inline __device__ float2 fma(float2 a, uint32_t b, float2 c) { + float2 tmp_b = half2_to_float2(b); + float2 d = fma(a, tmp_b, c); + return d; +} + inline __device__ float4 fma(float4 a, float4 b, float4 c) { float4 d; d.x = fma(a.x, b.x, c.x); @@ -524,6 +578,49 @@ inline __device__ float qk_dot_(const K_vec (&q)[N], return qk; } +inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) { + float4 c; + float zero = 0.f; + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum inv_q = mul(q[0], inv_sqrt_dh); + K_vec_acum qk_vec = mul(inv_q, k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + template struct Qk_dot { template @@ -534,6 +631,20 @@ struct Qk_dot { } }; +template <> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], + const uint32_t (&k)[N], + float inv_sqrt_dh) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750 + return qk_hmma_dot_(q, k, inv_sqrt_dh); +#else + return qk_dot_<4>(q, k, inv_sqrt_dh); +#endif + } +}; + template inline __device__ float block_sum(float *red_smem, float sum) { int warp = threadIdx.x / WARP_SIZE; @@ -606,6 +717,8 @@ template params) { #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + typedef PDDataTypeTraits traits_; + typedef typename traits_::DataType DataType_; static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); @@ -863,7 +976,7 @@ __global__ void masked_multihead_attention_kernel( float logit = logits_smem[ti]; out = fma(logit, cast_to_float(v), out); #else - T logit = logits_smem[ti]; + DataType_ logit = static_cast(logits_smem[ti]); // Update the partial sums. out = fma(logit, v, out); #endif @@ -987,7 +1100,11 @@ void fmha_launch_kernel(const Masked_multihead_attention_params ¶ms, if (params.timestep < 32) { MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream); } else if (params.timestep < 2048) { +#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750 + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream); +#else MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream); +#endif } else { MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream); } -- GitLab