diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index 9274146290d5f3be7cf1a67a53267d2e82c82ee8..59a79bcb699307b1be81a8cb54006f3daebe7fb9 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -75,6 +75,34 @@ __device__ inline void LayerNorm(const kvp &thread_data, const int ld, } } +template +__device__ inline void LayerNorm2(const kvp &thread_data, const int ld, + const int offset, const float2 *bias, + const float2 *scale, T2 *output, T eps) { + using BlockReduce = cub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T mu; // mean + __shared__ T rsigma; // 1 / std.dev. + + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum()); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = rsqrt(sum_kv.value - mu * mu + eps); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + T2 val = output[idx]; + const float2 g = scale[i]; + const float2 b = bias[i]; + val.x = T(g.x) * (val.x - mu) * rsigma + T(b.x); + val.y = T(g.y) * (val.y - mu) * rsigma + T(b.y); + output[idx] = val; + } +} + template __global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids, const float *scale, const float *bias, @@ -323,6 +351,27 @@ __global__ void SkipLayerNormKernel(int num, int hidden, const T *input1, LayerNorm(thread_data, hidden, offset, bias, scale, output, eps); } +template +__global__ void SkipLayerNormKernel2(int num, int hidden, const T2 *input1, + const T2 *input2, T2 *output, + const float2 *scale, const float2 *bias, + float eps) { + const T rld = T(0.5f / hidden); // because hidden is hidden/2 + const int offset = blockIdx.x * hidden; + cub::Sum pair_sum; + kvp thread_data(0, 0); + + for (int it = threadIdx.x; it < hidden; it += TPB) { + const int idx = offset + it; + const T2 val2 = input1[idx] + input2[idx]; + thread_data = pair_sum( + thread_data, kvp(rld * (val2.x + val2.y), + rld * val2.x * val2.x + rld * val2.y * val2.y)); + output[idx] = val2; + } + LayerNorm2(thread_data, hidden, offset, bias, scale, output, eps); +} + template void SkipLayerNormFunctor::operator()(const int num, const int hidden, const T *input1, const T *input2, @@ -344,8 +393,35 @@ void SkipLayerNormFunctor::operator()(const int num, const int hidden, num, hidden, input1, input2, output, scale, bias, eps); } else { const int threads = 256; - SkipLayerNormKernel<<>>( - num, hidden, input1, input2, output, scale, bias, eps); + if (hidden % 2 == 0) { +#ifdef SUPPORTS_CUDA_FP16 + if (std::is_same::value) { +#endif + SkipLayerNormKernel2<<>>( + num, hidden / 2, reinterpret_cast(input1), + reinterpret_cast(input2), + reinterpret_cast(output), + reinterpret_cast(scale), + reinterpret_cast(bias), eps); +#ifdef SUPPORTS_CUDA_FP16 + } else if (std::is_same::value) { + SkipLayerNormKernel2<__half, __half2, + threads><<>>( + num, hidden / 2, reinterpret_cast(input1), + reinterpret_cast(input2), + reinterpret_cast<__half2 *>(output), + reinterpret_cast(scale), + reinterpret_cast(bias), eps); + } else { + assert(false); + // should not be here + } +#endif + } else { + SkipLayerNormKernel<<>>( + num, hidden, input1, input2, output, scale, bias, eps); + } } } diff --git a/paddle/fluid/operators/math/math_cuda_utils.h b/paddle/fluid/operators/math/math_cuda_utils.h index 0325717b4d3714e8eae260beb89df7f2addda88f..1149914efbca4613757b3402624dd9ce3f62625f 100644 --- a/paddle/fluid/operators/math/math_cuda_utils.h +++ b/paddle/fluid/operators/math/math_cuda_utils.h @@ -66,7 +66,8 @@ __device__ __forceinline__ float2 ToFloat2(float2 a) { } template <> -__device__ __forceinline__ float2 FloatsToPair(const float a, const float b) { +__device__ __forceinline__ float2 FloatsToPair(const float a, + const float b) { return make_float2(a, b); } @@ -86,7 +87,8 @@ __device__ __forceinline__ float2 ToFloat2<__half2>(__half2 a) { } template <> -__device__ __forceinline__ __half2 FloatsToPair(const float a, const float b) { +__device__ __forceinline__ __half2 FloatsToPair<__half2>(const float a, + const float b) { return __floats2half2_rn(a, b); } #endif