未验证 提交 e528392d 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-TRT] SkipLayernorm vectorized memory optimization (#25117)

* add explicit specialization

* add skiplayernorm vector load if available

* test=develop
上级 4061aa64
......@@ -75,6 +75,34 @@ __device__ inline void LayerNorm(const kvp<T> &thread_data, const int ld,
}
}
template <typename T, typename T2, int TPB>
__device__ inline void LayerNorm2(const kvp<T> &thread_data, const int ld,
const int offset, const float2 *bias,
const float2 *scale, T2 *output, T eps) {
using BlockReduce = cub::BlockReduce<kvp<T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = rsqrt(sum_kv.value - mu * mu + eps);
}
__syncthreads();
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
T2 val = output[idx];
const float2 g = scale[i];
const float2 b = bias[i];
val.x = T(g.x) * (val.x - mu) * rsigma + T(b.x);
val.y = T(g.y) * (val.y - mu) * rsigma + T(b.y);
output[idx] = val;
}
}
template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids,
const float *scale, const float *bias,
......@@ -323,6 +351,27 @@ __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
LayerNorm<T, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
}
template <typename T, typename T2, unsigned TPB>
__global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1,
const T2 *input2, T2 *output,
const float2 *scale, const float2 *bias,
float eps) {
const T rld = T(0.5f / hidden); // because hidden is hidden/2
const int offset = blockIdx.x * hidden;
cub::Sum pair_sum;
kvp<T> thread_data(0, 0);
for (int it = threadIdx.x; it < hidden; it += TPB) {
const int idx = offset + it;
const T2 val2 = input1[idx] + input2[idx];
thread_data = pair_sum(
thread_data, kvp<T>(rld * (val2.x + val2.y),
rld * val2.x * val2.x + rld * val2.y * val2.y));
output[idx] = val2;
}
LayerNorm2<T, T2, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
}
template <typename T>
void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
const T *input1, const T *input2,
......@@ -344,8 +393,35 @@ void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
num, hidden, input1, input2, output, scale, bias, eps);
} else {
const int threads = 256;
SkipLayerNormKernel<T, threads><<<block, threads, 0, stream>>>(
num, hidden, input1, input2, output, scale, bias, eps);
if (hidden % 2 == 0) {
#ifdef SUPPORTS_CUDA_FP16
if (std::is_same<T, float>::value) {
#endif
SkipLayerNormKernel2<float, float2,
threads><<<block, threads, 0, stream>>>(
num, hidden / 2, reinterpret_cast<const float2 *>(input1),
reinterpret_cast<const float2 *>(input2),
reinterpret_cast<float2 *>(output),
reinterpret_cast<const float2 *>(scale),
reinterpret_cast<const float2 *>(bias), eps);
#ifdef SUPPORTS_CUDA_FP16
} else if (std::is_same<T, __half>::value) {
SkipLayerNormKernel2<__half, __half2,
threads><<<block, threads, 0, stream>>>(
num, hidden / 2, reinterpret_cast<const __half2 *>(input1),
reinterpret_cast<const __half2 *>(input2),
reinterpret_cast<__half2 *>(output),
reinterpret_cast<const float2 *>(scale),
reinterpret_cast<const float2 *>(bias), eps);
} else {
assert(false);
// should not be here
}
#endif
} else {
SkipLayerNormKernel<T, threads><<<block, threads, 0, stream>>>(
num, hidden, input1, input2, output, scale, bias, eps);
}
}
}
......
......@@ -66,7 +66,8 @@ __device__ __forceinline__ float2 ToFloat2<float2>(float2 a) {
}
template <>
__device__ __forceinline__ float2 FloatsToPair(const float a, const float b) {
__device__ __forceinline__ float2 FloatsToPair<float2>(const float a,
const float b) {
return make_float2(a, b);
}
......@@ -86,7 +87,8 @@ __device__ __forceinline__ float2 ToFloat2<__half2>(__half2 a) {
}
template <>
__device__ __forceinline__ __half2 FloatsToPair(const float a, const float b) {
__device__ __forceinline__ __half2 FloatsToPair<__half2>(const float a,
const float b) {
return __floats2half2_rn(a, b);
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册