Unverified commit d79eda71, authored by lzy, committed by GitHub

mma qk tensor_core (#48087)

* use mma for QK dot computing in fused_multi_transformer.
* Update fused_multi_transformer_op.cu.h
Parent 56f15c43
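This change routes the fp16 QK dot product of the masked multi-head attention step in fused_multi_transformer through a single mma.sync.m16n8k8 tensor-core instruction on sm_75+ GPUs, replacing the warp-shuffle reduction. As a scalar reference for what each logit works out to (an illustration only, with a hypothetical helper name, not code from the patch):

#include <cuda_fp16.h>

// Reference math for one QK logit: (sum_i q[i] * k[i]) * 1/sqrt(Dh), fp32 accumulation.
__device__ float qk_logit_reference(const half *q, const half *k, int dh,
                                    float inv_sqrt_dh) {
  float acc = 0.f;  // accumulate in fp32, as the kernel does
  for (int i = 0; i < dh; ++i) {
    acc += __half2float(q[i]) * __half2float(k[i]);
  }
  return acc * inv_sqrt_dh;  // scale by 1/sqrt(Dh)
}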
@@ -13,6 +13,8 @@ limitations under the License. */
// https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu
// We add License in the head.
#pragma once
#include <cuda_fp16.h>
#include <float.h>
@@ -88,6 +90,23 @@ using float16 = plat::float16;
#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#define MMHA_USE_FP32_ACUM_FOR_OUT
#define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_HMMA_FOR_REDUCTION
template <typename D>
class PDDataTypeTraits;
template <>
class PDDataTypeTraits<float> {
public:
typedef float DataType;
};
template <>
class PDDataTypeTraits<float16> {
public:
typedef half DataType;
};
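// Illustration only (not part of the diff): these traits let the kernel treat
// Paddle's float16 as CUDA's native half on device, e.g.
//   using DataType_ = typename PDDataTypeTraits<float16>::DataType;  // == half
// which is how the logits_smem values are cast later in this patch.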
template <typename T>
struct Masked_multihead_attention_params {
@@ -150,6 +169,17 @@ template <> struct V_vec_<float16, 2> { using Type = uint32_t; };
template <> struct V_vec_<float16, 4> { using Type = uint2; };
template <> struct V_vec_<float16, 8> { using Type = uint4; };
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template<typename T>
struct K_vec_acum_fp32_ {
};
template<>
struct K_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
#endif
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template <typename T> struct V_vec_acum_fp32_ {};
// template <> struct V_vec_acum_fp32_<float> { using Type = float; };
@@ -318,6 +348,15 @@ inline __device__ uint32_t mul(uint32_t a, float b) {
return res;
}
template <>
inline __device__ float2 mul(uint32_t a, float b) {
float2 tmp = half2_to_float2(a);
float2 res;
res.x = tmp.x * b;
res.y = tmp.y * b;
return res;
}
template <>
inline __device__ uint2 mul(uint2 a, float b) {
uint2 res;
@@ -344,6 +383,15 @@ inline __device__ float2 mul(float2 a, float b) {
return res;
}
template <>
inline __device__ float2 mul(float2 a, uint32_t b) {
float2 tmp_b = half2_to_float2(b);
float2 res;
res.x = a.x * tmp_b.x;
res.y = a.y * tmp_b.y;
return res;
}
template <>
inline __device__ float4 mul(float4 a, float b) {
float4 res;
@@ -403,6 +451,12 @@ inline __device__ float2 fma(float2 a, float2 b, float2 c) {
return d;
}
inline __device__ float2 fma(float2 a, uint32_t b, float2 c) {
float2 tmp_b = half2_to_float2(b);
float2 d = fma(a, tmp_b, c);
return d;
}
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
@@ -524,6 +578,49 @@ inline __device__ float qk_dot_(const K_vec (&q)[N],
return qk;
}
inline __device__ float4 hmma_fp32_tensorcore(const uint2 &a, uint32_t b) {
float4 c;
float zero = 0.f;
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5}, \n"
" {%6}, \n"
" {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
return c;
}
template <int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N],
const uint32_t (&k)[N],
float inv_sqrt_dh) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
#else
using K_vec_acum = uint32_t;
#endif
K_vec_acum inv_q = mul<K_vec_acum, uint32_t, float>(q[0], inv_sqrt_dh);
K_vec_acum qk_vec = mul<K_vec_acum, K_vec_acum, uint32_t>(inv_q, k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
inv_q = mul<K_vec_acum, uint32_t, float>(q[ii], inv_sqrt_dh);
qk_vec = fma(inv_q, k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t qk_vec_ = float2_to_half2(qk_vec);
return hmma_fp32_tensorcore(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32_tensorcore(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
return 0.f;
#endif
}
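// Illustration only (not code from the patch): 0x3c003c00 packs two fp16 ones,
// so the mma above multiplies the accumulated half2 by {1.0, 1.0}. Because the
// K dimension of m16n8k8 is spread across 4 threads, the instruction also sums
// the partial products of the 4 threads sharing one key (THREADS_PER_KEY == 4).
// A shuffle-based sketch of that same final reduction:
inline __device__ float qk_hmma_reduce_reference(float2 qk_vec) {
  float sum = qk_vec.x + qk_vec.y;  // add the two lanes held by this thread
#pragma unroll
  for (int mask = 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(0xFFFFFFFFu, sum, mask);  // reduce across the 4-thread group
  }
  return sum;
}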
template <typename T, int THREADS_PER_KEY>
struct Qk_dot {
template <typename K_vec, int N>
@@ -534,6 +631,20 @@ struct Qk_dot {
}
};
template <>
struct Qk_dot<float16, 4> {
template <int N>
static inline __device__ float dot(const uint32_t (&q)[N],
const uint32_t (&k)[N],
float inv_sqrt_dh) {
#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
return qk_hmma_dot_(q, k, inv_sqrt_dh);
#else
return qk_dot_<4>(q, k, inv_sqrt_dh);
#endif
}
};
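// Illustrative call shape (the surrounding kernel loop is not part of this hunk):
//   float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec, inv_sqrt_dh);
// so the float16 / THREADS_PER_KEY == 4 specialization above takes the
// tensor-core path on sm_75+ and falls back to the shuffle reduction otherwise.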
template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float *red_smem, float sum) {
int warp = threadIdx.x / WARP_SIZE;
@@ -606,6 +717,8 @@ template <typename T,
__global__ void masked_multihead_attention_kernel(
Masked_multihead_attention_params<T> params) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
typedef PDDataTypeTraits<T> traits_;
typedef typename traits_::DataType DataType_;
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
@@ -863,7 +976,7 @@ __global__ void masked_multihead_attention_kernel(
float logit = logits_smem[ti];
out = fma(logit, cast_to_float(v), out);
#else
T logit = logits_smem[ti];
DataType_ logit = static_cast<DataType_>(logits_smem[ti]);
// Update the partial sums.
out = fma(logit, v, out);
#endif
@@ -987,7 +1100,11 @@ void fmha_launch_kernel(const Masked_multihead_attention_params<T> &params,
if (params.timestep < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, stream);
} else if (params.timestep < 2048) {
#if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 256, stream);
#else
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, stream);
#endif
} else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, stream);
}
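// Note on the branch above (illustration, not part of the diff): with the
// tensor-core reduction compiled in (MMHA_USE_HMMA_FOR_REDUCTION on sm_75+),
// the 32 <= timestep < 2048 case launches with THREADS_PER_KEY = 4 and a
// 256-thread block so each key maps onto a 4-thread quad, matching the
// m16n8k8 fragment layout; otherwise the original 2-threads-per-key,
// 128-thread configuration is kept.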