Commit 336c03bf authored by Vincent Huang, committed by Rajeev Rao

Integrate mha plugin to handle head 1~32 and sequence 512

Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
Parent commit: 221fd923
......@@ -68,11 +68,15 @@ documentation.
## Changelog
October 2020
Add v2 plugin that supports variable sequence length.
September 2021
Add sequence length 512 support in v2 plugin
Add head size 32 support when sequence length is 128, 256 or 512 in v2 plugin
October 2020
Add v2 plugin that supports variable sequence length.
Add v3 plugin that supports int8 interleaved variable sequence length.
November 2019
November 2019
This is the first release of this `README.md` file.
......
......@@ -213,17 +213,14 @@ public:
{
}
void loadXMMAKernels()
void loadXMMAKernels(uint32_t smVersion)
{
if (!mFunctions.empty())
{
return;
}
for (uint32_t i = 0; i < mKernelMetaCount; ++i)
{
const auto& kernelMeta = mKernelMeta[i];
if (kernelMeta.mSM == mSM && kernelMeta.mDataType == mDataType)
const auto kernelKey = hashID(kernelMeta);
if (kernelMeta.mSM == smVersion && kernelMeta.mDataType == mDataType
&& mFunctions.find(kernelKey) == mFunctions.end())
{
CUmodule hmod{0};
auto findModuleIter = mModules.find(kernelMeta.mCubin);
......@@ -240,20 +237,43 @@ public:
FusedMultiHeadAttentionKernelInfo funcInfo;
funcInfo.mMetaInfoIndex = i;
cuErrCheck(mDriver.cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
if (kernelMeta.mSharedMemBytes >= 48 * 1024)
const uint32_t DEFAULT_SMEM_SIZE{48 * 1024};
if (kernelMeta.mSharedMemBytes >= DEFAULT_SMEM_SIZE)
{
cuErrCheck(mDriver.cuFuncSetAttribute(funcInfo.mDeviceFunction,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, kernelMeta.mSharedMemBytes),
mDriver);
if (mDriver.cuFuncSetAttribute(funcInfo.mDeviceFunction,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, kernelMeta.mSharedMemBytes)
!= CUDA_SUCCESS)
{
// some chip may not have enough shared memory to launch the kernel
continue;
}
}
mFunctions.insert(std::make_pair(hashID(kernelMeta), funcInfo));
int s = static_cast<int>(kernelMeta.mS);
mFunctions.insert({kernelKey, funcInfo});
const int s = static_cast<int>(kernelMeta.mS);
if (mValidSequences.find(s) == mValidSequences.end())
{
mValidSequences.insert(s);
}
}
}
}
void loadXMMAKernels()
{
if (!mFunctions.empty())
{
return;
}
loadXMMAKernels(mSM);
// sm_86 chips prefer sm_86 sass, but can also use sm_80 sass if sm_86 not exist.
if (mSM != kSM_80 && mSM / 10U == 8)
{
loadXMMAKernels(kSM_80);
}
}
bool isValid(int s) const
{
return (mValidSequences.find(s) != mValidSequences.end());
......
......@@ -141,6 +141,23 @@ extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin[
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin[];
extern unsigned char fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin[];
extern unsigned char cubin_fmha_v2_int8_512_64_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_int8_512_32_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_int8_256_32_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_int8_128_32_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_512_64_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_512_32_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_256_32_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_128_32_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_int8_512_64_sm75_cu_cubin[];
extern unsigned char cubin_fmha_v2_int8_512_32_sm75_cu_cubin[];
extern unsigned char cubin_fmha_v2_int8_256_32_sm75_cu_cubin[];
extern unsigned char cubin_fmha_v2_int8_128_32_sm75_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_512_64_sm75_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_512_32_sm75_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_256_32_sm75_cu_cubin[];
extern unsigned char cubin_fmha_v2_fp16_128_32_sm75_cu_cubin[];
extern uint32_t fused_multihead_attention_v2_fp16_128_64_kernel_sm75_cubin_len;
extern uint32_t fused_multihead_attention_v2_fp16_128_64_kernel_sm80_cubin_len;
extern uint32_t fused_multihead_attention_v2_fp16_128_64_kernel_sm86_cubin_len;
......@@ -173,6 +190,23 @@ extern uint32_t fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin_len;
extern uint32_t fused_multihead_attention_v2_int8_384_64_kernel_sm80_cubin_len;
extern uint32_t fused_multihead_attention_v2_int8_384_64_kernel_sm86_cubin_len;
extern uint32_t cubin_fmha_v2_int8_512_64_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_int8_512_32_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_int8_256_32_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_int8_128_32_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_512_64_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_512_32_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_256_32_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_128_32_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_int8_512_64_sm75_cu_cubin_len;
extern uint32_t cubin_fmha_v2_int8_512_32_sm75_cu_cubin_len;
extern uint32_t cubin_fmha_v2_int8_256_32_sm75_cu_cubin_len;
extern uint32_t cubin_fmha_v2_int8_128_32_sm75_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_512_64_sm75_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_512_32_sm75_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_256_32_sm75_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_128_32_sm75_cu_cubin_len;
static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{
Data_type mDataType;
......@@ -288,7 +322,93 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
fused_multihead_attention_v2_int8_384_64_kernel_sm75_cubin_len,
"fused_multihead_attention_v2_int8_384_64_kernel_sm75", 51200, 128, 0, false},
// {fp16} x {512} x {64, 32} x {sm75} x {normal, noloop}
{DATA_TYPE_FP16, 512, 64, kSM_75, cubin_fmha_v2_fp16_512_64_sm75_cu_cubin,
cubin_fmha_v2_fp16_512_64_sm75_cu_cubin_len, "fmha_v2_fp16_512_64_sm75_kernel", 69632, 256, 0, false},
{DATA_TYPE_FP16, 512, 64, kSM_75, cubin_fmha_v2_fp16_512_64_sm75_cu_cubin,
cubin_fmha_v2_fp16_512_64_sm75_cu_cubin_len, "fmha_v2_fp16_512_64_sm75_kernel_nl", 69632, 256, 32, false},
{DATA_TYPE_FP16, 512, 32, kSM_75, cubin_fmha_v2_fp16_512_32_sm75_cu_cubin,
cubin_fmha_v2_fp16_512_32_sm75_cu_cubin_len, "fmha_v2_fp16_512_32_sm75_kernel", 36864, 256, 0, false},
{DATA_TYPE_FP16, 512, 32, kSM_75, cubin_fmha_v2_fp16_512_32_sm75_cu_cubin,
cubin_fmha_v2_fp16_512_32_sm75_cu_cubin_len, "fmha_v2_fp16_512_32_sm75_kernel_nl", 36864, 256, 32, false},
// {fp16, int8} x {128} x {32} x {sm75} x {normal, noloop}
{DATA_TYPE_INT8, 128, 32, kSM_75, cubin_fmha_v2_int8_128_32_sm75_cu_cubin,
cubin_fmha_v2_int8_128_32_sm75_cu_cubin_len, "fmha_v2_int8_128_32_sm75_kernel", 12288, 128, 0, false},
{DATA_TYPE_INT8, 128, 32, kSM_75, cubin_fmha_v2_int8_128_32_sm75_cu_cubin,
cubin_fmha_v2_int8_128_32_sm75_cu_cubin_len, "fmha_v2_int8_128_32_sm75_kernel_nl", 10240, 128, 32, false},
{DATA_TYPE_FP16, 128, 32, kSM_75, cubin_fmha_v2_fp16_128_32_sm75_cu_cubin,
cubin_fmha_v2_fp16_128_32_sm75_cu_cubin_len, "fmha_v2_fp16_128_32_sm75_kernel", 16384, 128, 0, false},
{DATA_TYPE_FP16, 128, 32, kSM_75, cubin_fmha_v2_fp16_128_32_sm75_cu_cubin,
cubin_fmha_v2_fp16_128_32_sm75_cu_cubin_len, "fmha_v2_fp16_128_32_sm75_kernel_nl", 10240, 128, 32, false},
// {fp16, int8} x {256} x {32} x {sm75} x {normal, noloop}
{DATA_TYPE_INT8, 256, 32, kSM_75, cubin_fmha_v2_int8_256_32_sm75_cu_cubin,
cubin_fmha_v2_int8_256_32_sm75_cu_cubin_len, "fmha_v2_int8_256_32_sm75_kernel", 18432, 128, 0, false},
{DATA_TYPE_INT8, 256, 32, kSM_75, cubin_fmha_v2_int8_256_32_sm75_cu_cubin,
cubin_fmha_v2_int8_256_32_sm75_cu_cubin_len, "fmha_v2_int8_256_32_sm75_kernel_nl", 18432, 128, 32, false},
{DATA_TYPE_FP16, 256, 32, kSM_75, cubin_fmha_v2_fp16_256_32_sm75_cu_cubin,
cubin_fmha_v2_fp16_256_32_sm75_cu_cubin_len, "fmha_v2_fp16_256_32_sm75_kernel", 18432, 128, 0, false},
{DATA_TYPE_FP16, 256, 32, kSM_75, cubin_fmha_v2_fp16_256_32_sm75_cu_cubin,
cubin_fmha_v2_fp16_256_32_sm75_cu_cubin_len, "fmha_v2_fp16_256_32_sm75_kernel_nl", 18432, 128, 32, false},
// {int8} x {512} x {64, 32} x {sm75} x {normal, noloop}
{DATA_TYPE_INT8, 512, 64, kSM_75, cubin_fmha_v2_int8_512_64_sm75_cu_cubin,
cubin_fmha_v2_int8_512_64_sm75_cu_cubin_len, "fmha_v2_int8_512_64_sm75_kernel", 69632, 256, 0, false},
{DATA_TYPE_INT8, 512, 64, kSM_75, cubin_fmha_v2_int8_512_64_sm75_cu_cubin,
cubin_fmha_v2_int8_512_64_sm75_cu_cubin_len, "fmha_v2_int8_512_64_sm75_kernel_nl", 69632, 256, 32, false},
{DATA_TYPE_INT8, 512, 32, kSM_75, cubin_fmha_v2_int8_512_32_sm75_cu_cubin,
cubin_fmha_v2_int8_512_32_sm75_cu_cubin_len, "fmha_v2_int8_512_32_sm75_kernel", 36864, 256, 0, false},
{DATA_TYPE_INT8, 512, 32, kSM_75, cubin_fmha_v2_int8_512_32_sm75_cu_cubin,
cubin_fmha_v2_int8_512_32_sm75_cu_cubin_len, "fmha_v2_int8_512_32_sm75_kernel_nl", 36864, 256, 32, false},
#if CUDA_VERSION >= 11000
// {fp16} x {128} x {32} x {sm80} x {normal, noloop}
{DATA_TYPE_FP16, 128, 32, kSM_80, cubin_fmha_v2_fp16_128_32_sm80_cu_cubin,
cubin_fmha_v2_fp16_128_32_sm80_cu_cubin_len, "fmha_v2_fp16_128_32_sm80_kernel", 32768, 128, 0, false},
{DATA_TYPE_FP16, 128, 32, kSM_80, cubin_fmha_v2_fp16_128_32_sm80_cu_cubin,
cubin_fmha_v2_fp16_128_32_sm80_cu_cubin_len, "fmha_v2_fp16_128_32_sm80_kernel_nl", 20480, 128, 16, false},
// {int8} x {128} x {32} x {sm80} x {normal, noloop, interleave, interleave_noloop}
{DATA_TYPE_INT8, 128, 32, kSM_80, cubin_fmha_v2_int8_128_32_sm80_cu_cubin,
cubin_fmha_v2_int8_128_32_sm80_cu_cubin_len, "fmha_v2_int8_128_32_sm80_kernel", 16384, 128, 0, false},
{DATA_TYPE_INT8, 128, 32, kSM_80, cubin_fmha_v2_int8_128_32_sm80_cu_cubin,
cubin_fmha_v2_int8_128_32_sm80_cu_cubin_len, "fmha_v2_int8_128_32_sm80_kernel_nl", 12288, 128, 16, false},
{DATA_TYPE_INT8, 128, 32, kSM_80, cubin_fmha_v2_int8_128_32_sm80_cu_cubin,
cubin_fmha_v2_int8_128_32_sm80_cu_cubin_len, "fmha_v2_int8_128_32_sm80_kernel", 12288, 128, 0, true},
{DATA_TYPE_INT8, 128, 32, kSM_80, cubin_fmha_v2_int8_128_32_sm80_cu_cubin,
cubin_fmha_v2_int8_128_32_sm80_cu_cubin_len, "fmha_v2_il_int8_128_32_sm80_kernel_nl", 10240, 128, 16, true},
// {fp16, int8} x {256} x {32} x {sm80} x {normal, noloop}
{DATA_TYPE_INT8, 256, 32, kSM_80, cubin_fmha_v2_int8_256_32_sm80_cu_cubin,
cubin_fmha_v2_int8_256_32_sm80_cu_cubin_len, "fmha_v2_int8_256_32_sm80_kernel", 20480, 128, 0, false},
{DATA_TYPE_INT8, 256, 32, kSM_80, cubin_fmha_v2_int8_256_32_sm80_cu_cubin,
cubin_fmha_v2_int8_256_32_sm80_cu_cubin_len, "fmha_v2_int8_256_32_sm80_kernel_nl", 20480, 128, 32, false},
{DATA_TYPE_FP16, 256, 32, kSM_80, cubin_fmha_v2_fp16_256_32_sm80_cu_cubin,
cubin_fmha_v2_fp16_256_32_sm80_cu_cubin_len, "fmha_v2_fp16_256_32_sm80_kernel", 20480, 128, 0, false},
{DATA_TYPE_FP16, 256, 32, kSM_80, cubin_fmha_v2_fp16_256_32_sm80_cu_cubin,
cubin_fmha_v2_fp16_256_32_sm80_cu_cubin_len, "fmha_v2_fp16_256_32_sm80_kernel_nl", 20480, 128, 32, false},
// {int8} x {512} x {64, 32} x {sm80} x {normal, noloop}
{DATA_TYPE_INT8, 512, 64, kSM_80, cubin_fmha_v2_int8_512_64_sm80_cu_cubin,
cubin_fmha_v2_int8_512_64_sm80_cu_cubin_len, "fmha_v2_int8_512_64_sm80_kernel", 73728, 256, 0, false},
{DATA_TYPE_INT8, 512, 64, kSM_80, cubin_fmha_v2_int8_512_64_sm80_cu_cubin,
cubin_fmha_v2_int8_512_64_sm80_cu_cubin_len, "fmha_v2_int8_512_64_sm80_kernel_nl", 73728, 256, 32, false},
{DATA_TYPE_INT8, 512, 32, kSM_80, cubin_fmha_v2_int8_512_32_sm80_cu_cubin,
cubin_fmha_v2_int8_512_32_sm80_cu_cubin_len, "fmha_v2_int8_512_32_sm80_kernel", 40960, 256, 0, false},
{DATA_TYPE_INT8, 512, 32, kSM_80, cubin_fmha_v2_int8_512_32_sm80_cu_cubin,
cubin_fmha_v2_int8_512_32_sm80_cu_cubin_len, "fmha_v2_int8_512_32_sm80_kernel_nl", 40960, 256, 32, false},
// {fp16} x {512} x {64, 32} x {sm80} x {normal, noloop}
{DATA_TYPE_FP16, 512, 64, kSM_80, cubin_fmha_v2_fp16_512_64_sm80_cu_cubin,
cubin_fmha_v2_fp16_512_64_sm80_cu_cubin_len, "fmha_v2_fp16_512_64_sm80_kernel", 73728, 256, 0, false},
{DATA_TYPE_FP16, 512, 64, kSM_80, cubin_fmha_v2_fp16_512_64_sm80_cu_cubin,
cubin_fmha_v2_fp16_512_64_sm80_cu_cubin_len, "fmha_v2_fp16_512_64_sm80_kernel_nl", 73728, 256, 32, false},
{DATA_TYPE_FP16, 512, 32, kSM_80, cubin_fmha_v2_fp16_512_32_sm80_cu_cubin,
cubin_fmha_v2_fp16_512_32_sm80_cu_cubin_len, "fmha_v2_fp16_512_32_sm80_kernel", 40960, 256, 0, false},
{DATA_TYPE_FP16, 512, 32, kSM_80, cubin_fmha_v2_fp16_512_32_sm80_cu_cubin,
cubin_fmha_v2_fp16_512_32_sm80_cu_cubin_len, "fmha_v2_fp16_512_32_sm80_kernel_nl", 40960, 256, 32, false},
// Ampere
{DATA_TYPE_FP16, 64, 64, kSM_80, fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin,
fused_multihead_attention_v2_fp16_64_64_kernel_sm80_cubin_len,
......@@ -391,7 +511,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
fused_multihead_attention_v2_fp16_384_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_fp16_384_64_kernel_sm80", 65536, 256, 0, false},
{DATA_TYPE_INT8, 128, 64, kSM_86, fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin,
fused_multihead_attention_v2_int8_128_64_kernel_sm86_cubin_len,
"fused_multihead_attention_v2_int8_128_64_kernel_sm80_interleaved_noloop", 20480, 128, 16, true},
......@@ -455,20 +574,20 @@ public:
{
}
inline uint64_t hashID(uint32_t s, bool interleaved, bool unroll) const
inline uint64_t hashID(uint32_t s, uint32_t headsize, bool interleaved, bool unroll) const
{
return (uint64_t) s << 32 | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
// we only have 30 bits room for head size
ASSERT(headsize <= 0x3FFFFFFF);
return static_cast<uint64_t>(s) << 32 | (headsize << 2) | (interleaved ? 2U : 0U) | (unroll ? 1U : 0U);
}
virtual uint64_t hashID(const KernelMeta& kernelMeta) const
{
assert(kernelMeta.mD == 64);
return hashID(kernelMeta.mS, kernelMeta.mInterleaved, kernelMeta.mUnrollStep);
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep);
}
virtual void run(Fused_multihead_attention_params_v2& params, cudaStream_t ss) const
{
assert(params.d == 64);
if (params.interleaved)
{
assert(mDataType == bert::DATA_TYPE_INT8);
......@@ -518,7 +637,7 @@ public:
}
}
const auto findIter = mFunctions.find(hashID(params.s, params.interleaved, forceUnroll));
const auto findIter = mFunctions.find(hashID(params.s, params.d, params.interleaved, forceUnroll));
ASSERT(findIter != mFunctions.end());
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
......
......@@ -503,8 +503,8 @@ void UnfusedMHARunner::setup(const int S, const int B)
std::tie(mAlgoBatchedEx1, mAlgoBatchedEx2) = tuneBatchedGemm(B, S, mNumHeads, mHeadSize, mSm);
mIsBestAlgoFound = true;
gLogVerbose << "QKV Plugin - Selected Algos for batch gemms: " << mAlgoBatchedEx1 << ", " << mAlgoBatchedEx2
<< "\n";
BERT_DEBUG_VALUE("QKV Plugin - Selected Algo 1 for batch gemms: ", mAlgoBatchedEx1);
BERT_DEBUG_VALUE("QKV Plugin - Selected Algo 2 for batch gemms: ", mAlgoBatchedEx2);
}
}
......@@ -934,7 +934,7 @@ public:
warps_m = 1;
warps_n = 4;
}
else if (S == 384)
else if (S == 384 || S == 512)
{
warps_m = 1;
warps_n = 8;
......@@ -979,7 +979,8 @@ public:
params.qkv_ptr = const_cast<void*>(qkvPtr);
params.packed_mask_ptr = const_cast<void*>(maskPtr);
// dummy input in V2/V3 because now we use cu_seqlens
params.packed_mask_ptr = nullptr;
params.o_ptr = output;
......@@ -1085,7 +1086,7 @@ public:
warps_m = 1;
warps_n = 4;
}
else if (S == 384)
else if (S == 384 || S == 512)
{
warps_m = 1;
warps_n = 8;
......@@ -1131,7 +1132,8 @@ public:
params.qkv_ptr = const_cast<void*>(qkvPtr);
params.packed_mask_ptr = const_cast<void*>(maskPtr);
// dummy input in V2/V3 because now we use cu_seqlens
params.packed_mask_ptr = nullptr;
params.use_int8_scale_max = true;
......
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -23,6 +23,7 @@
#include "NvInferPlugin.h"
#include "cublas_v2.h"
#include "zeroPadding2d.h"
#include <string>
#include <vector>
......@@ -255,6 +256,7 @@ private:
std::string mNamespace;
std::unique_ptr<MHARunner> dispatcher;
std::unique_ptr<QkvPaddingRunner> patcher;
int mS;
int mB;
......
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "checkMacrosPlugin.h"
#include "zeroPadding2d.h"
#include <array>
#include <cassert>
#include <cstring>
using namespace nvinfer1;
namespace bert
{
constexpr int32_t kMAX_THREADS_PER_BLOCK{256};
// Grid-stride 2D copy with zero padding: each destination row holds `dpitch`
// elements, of which the first `spitch` are copied from `src` and the remaining
// `dpitch - spitch` are zero-filled. Pitches are element counts (not bytes);
// requires dpitch >= spitch. Launch with any 1D grid/block configuration.
template <typename TDataType>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
    zeroPadding2dKernel(const TDataType* src, int32_t spitch, TDataType* dst, int32_t dpitch, int32_t height)
{
    int32_t uid = blockIdx.x * blockDim.x + threadIdx.x;
    int32_t numElements = dpitch * height;
    int32_t numThreads = gridDim.x * blockDim.x;
#pragma unroll
    for (; uid < numElements; uid += numThreads)
    {
        int32_t ty = uid / dpitch;
        // Defensive bound check; use `continue` (not `return`) so this matches
        // the int4 specialization below and every grid-stride iteration is
        // examined uniformly. (uid < numElements already implies ty < height.)
        if (ty >= height)
        {
            continue;
        }
        int32_t tx = uid % dpitch;
        TDataType val = 0;
        if (tx < spitch)
        {
            val = src[ty * spitch + tx];
        }
        dst[ty * dpitch + tx] = val;
    }
}
// Specialization for 16-byte (int4) vectorized accesses: identical logic to the
// generic kernel, but int4 has no scalar zero-initializer, so the fill value is
// spelled out as {0, 0, 0, 0}. Pitches are counts of int4 elements.
template <>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
    zeroPadding2dKernel(const int4* src, int32_t spitch, int4* dst, int32_t dpitch, int32_t height)
{
    int32_t uid = blockIdx.x * blockDim.x + threadIdx.x;
    int32_t numElements = dpitch * height;
    int32_t numThreads = gridDim.x * blockDim.x;
#pragma unroll
    for (; uid < numElements; uid += numThreads)
    {
        int32_t ty = uid / dpitch;
        // Defensive bound check; uid < numElements already implies ty < height.
        if (ty >= height)
        {
            continue;
        }
        int32_t tx = uid % dpitch;
        int4 val{0, 0, 0, 0};
        // Copy within the source row, zero-fill the padded tail.
        if (tx < spitch)
        {
            val = src[ty * spitch + tx];
        }
        dst[ty * dpitch + tx] = val;
    }
}
//! Asynchronously copy a (height x spitch)-byte 2D region into a
//! (height x dpitch)-byte region, zero-filling the trailing bytes of each
//! destination row. spitch/dpitch are BYTE pitches here; the widest element
//! type (1/2/4/8/16 bytes) that divides both pointers and both pitches is
//! selected so the kernel can use the largest aligned accesses.
//! Returns cudaPeekAtLastError() after the launch on `stream`.
cudaError_t zeroPadding2d(
    const void* src, int32_t spitch, void* dst, int32_t dpitch, int32_t height, cudaStream_t stream)
{
    // All instantiations are stored behind one type-erased signature so they
    // can live in a single table indexed by element width.
    // NOTE(review): calling through a reinterpret_cast'd function-pointer type
    // is technically UB in ISO C++; it relies on all instantiations sharing the
    // same pointer/int parameter layout — presumably safe on the supported
    // toolchains, but worth confirming.
    using kernel_ptr_t = void (*)(const void* src, int32_t spitch, void* dst, int32_t dpitch, int32_t height);
    kernel_ptr_t kernels[5]{reinterpret_cast<kernel_ptr_t>(zeroPadding2dKernel<int8_t>),
        reinterpret_cast<kernel_ptr_t>(zeroPadding2dKernel<int16_t>),
        reinterpret_cast<kernel_ptr_t>(zeroPadding2dKernel<int32_t>),
        reinterpret_cast<kernel_ptr_t>(zeroPadding2dKernel<int64_t>),
        reinterpret_cast<kernel_ptr_t>(zeroPadding2dKernel<int4>)};
    // Map a byte count/address to the index of the widest kernel whose element
    // size divides it: 4 -> 16B, 3 -> 8B, 2 -> 4B, 1 -> 2B, 0 -> 1B.
    auto select = [](size_t width) -> int32_t {
        if (!(width & 0xF))
        {
            return 4;
        }
        if (!(width & 0x7))
        {
            return 3;
        }
        if (!(width & 0x3))
        {
            return 2;
        }
        if (!(width & 0x1))
        {
            return 1;
        }
        return 0;
    };
    auto kernelId = 4; // 128 bit access
    // Both pointers and both pitches must satisfy the chosen alignment; take
    // the minimum over all four.
    std::array<size_t, 4> checkAlignment{reinterpret_cast<size_t>(src), static_cast<size_t>(spitch),
        reinterpret_cast<size_t>(dst), static_cast<size_t>(dpitch)};
    for (auto size : checkAlignment)
    {
        auto shiftId = select(size);
        if (shiftId < kernelId)
        {
            kernelId = shiftId;
        }
    }
    // Convert byte pitches to element counts for the selected element width
    // (element size is 1 << kernelId bytes).
    spitch >>= kernelId;
    dpitch >>= kernelId;
    int32_t devId;
    CHECK_CUDA(cudaGetDevice(&devId));
    int32_t numSms;
    CHECK_CUDA(cudaDeviceGetAttribute(&numSms, cudaDevAttrMultiProcessorCount, devId));
    auto kernel = kernels[kernelId];
    int32_t block = kMAX_THREADS_PER_BLOCK;
    // One thread per destination element, then clamp the grid to what the
    // device can keep resident; the grid-stride loop covers the remainder.
    int32_t grid = (dpitch * height + kMAX_THREADS_PER_BLOCK - 1) / kMAX_THREADS_PER_BLOCK;
    int32_t blocksPerSm;
    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSm, kernel, block, 0));
    grid = std::min(numSms * blocksPerSm, grid);
    kernel<<<grid, block, 0, stream>>>(src, spitch, dst, dpitch, height);
    // Peek (not get) so a sticky error is left for the caller's error chain.
    return cudaPeekAtLastError();
}
// Prepares a runner that pads/unpads QKV head rows. Only head sizes in (0, 64]
// and HALF/INT8 data types are supported; the effective head size is rounded up
// to the nearest fused-kernel size (32 or 64).
QkvPaddingRunner::QkvPaddingRunner(int32_t headSize, DataType dtype)
{
    ASSERT(headSize > 0 && headSize <= 64);
    if (headSize <= 32)
    {
        mPaddingHeadSize = 32;
    }
    else
    {
        mPaddingHeadSize = 64;
    }
    ASSERT(dtype == DataType::kHALF || dtype == DataType::kINT8);
    // Element width in bytes: 2 for fp16, 1 for int8.
    mDtypeSize = (dtype == DataType::kHALF) ? 2 : 1;
}
//! Head size actually used after padding (32 or 64), as chosen by the ctor.
int32_t QkvPaddingRunner::getPaddingHeadSize()
{
    return mPaddingHeadSize;
}
//! Bytes required for the padded QKV input: 3 (Q, K, V) matrices of
//! sumSeqLen * numHeads rows, each padded to mPaddingHeadSize elements of
//! mDtypeSize bytes.
size_t QkvPaddingRunner::getInputSize(int32_t sumSeqLen, int32_t numHeads)
{
    // Widen to size_t before multiplying: the original 32-bit unsigned product
    // could overflow for large sumSeqLen * numHeads before the conversion.
    return 3U * static_cast<size_t>(sumSeqLen) * numHeads * mPaddingHeadSize * mDtypeSize;
}
//! Bytes required for the padded attention output: one matrix of
//! sumSeqLen * numHeads rows of mPaddingHeadSize elements of mDtypeSize bytes.
size_t QkvPaddingRunner::getOutputSize(int32_t sumSeqLen, int32_t numHeads)
{
    // Widen to size_t before multiplying: the original 32-bit unsigned product
    // could overflow for large sumSeqLen * numHeads before the conversion.
    return static_cast<size_t>(sumSeqLen) * numHeads * mPaddingHeadSize * mDtypeSize;
}
//! Total scratch bytes needed: the padded input, the padded output, and 16
//! spare bytes so both regions can be shifted onto a 16-byte boundary.
size_t QkvPaddingRunner::getWorkspaceSize(int32_t sumSeqLen, int32_t numHeads)
{
    constexpr int32_t reserveForAlignment = 16;
    const size_t paddedInputBytes = getInputSize(sumSeqLen, numHeads);
    const size_t paddedOutputBytes = getOutputSize(sumSeqLen, numHeads);
    return paddedInputBytes + paddedOutputBytes + reserveForAlignment;
}
//! Returns workspace + offset rounded UP to the next 16-byte boundary (the
//! address is returned unchanged when it is already aligned).
void* QkvPaddingRunner::get16BytesAlignedPointer(void* workspace, size_t offset)
{
    const auto base = reinterpret_cast<uintptr_t>(workspace) + offset;
    // Classic round-up: add 15 and clear the low four bits.
    const uintptr_t aligned = (base + 0xF) & ~static_cast<uintptr_t>(0xF);
    return reinterpret_cast<void*>(aligned);
}
//! Widen each head row from headSize to mPaddingHeadSize elements, zero-filling
//! the new tail. Q, K and V are packed together, hence 3 * sumSeqLen * numHeads
//! rows. Asynchronous on `stream`; returns the launch status.
cudaError_t QkvPaddingRunner::pad(
    const void* src, void* workspace, int32_t sumSeqLen, int32_t numHeads, int32_t headSize, cudaStream_t stream)
{
    const int32_t srcPitchBytes = headSize * mDtypeSize;
    const int32_t dstPitchBytes = mPaddingHeadSize * mDtypeSize;
    const int32_t numRows = 3 * sumSeqLen * numHeads;
    return zeroPadding2d(src, srcPitchBytes, workspace, dstPitchBytes, numRows, stream);
}
//! Inverse of pad(): narrow each output head row from mPaddingHeadSize back to
//! headSize elements (the padded tail is dropped). Only the single output
//! matrix is processed, hence sumSeqLen * numHeads rows. Asynchronous on
//! `stream`; returns the launch status.
cudaError_t QkvPaddingRunner::unpad(
    const void* workspace, void* dst, int32_t sumSeqLen, int32_t numHeads, int32_t headSize, cudaStream_t stream)
{
    const int32_t srcPitchBytes = mPaddingHeadSize * mDtypeSize;
    const int32_t dstPitchBytes = headSize * mDtypeSize;
    const int32_t numRows = sumSeqLen * numHeads;
    return zeroPadding2d(workspace, srcPitchBytes, dst, dstPitchBytes, numRows, stream);
}