From addd5fceb28e8e7704a6afd73a358369a1c31f62 Mon Sep 17 00:00:00 2001 From: wenbin Date: Wed, 11 Aug 2021 14:06:08 +0800 Subject: [PATCH] miss format (#34771) --- .../fluid/operators/math/bert_encoder_functor.cu | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index 4d7218cd89e..645d1f63718 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -25,6 +25,14 @@ namespace paddle { namespace operators { namespace math { +template +__device__ __forceinline__ T local_rsqrt(T num) { + return rsqrt(static_cast(num)); +} +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +__device__ __forceinline__ half local_rsqrt(half num) { return hrsqrt(num); } +#endif + template __device__ inline void LayerNormSmall(T val, const kvp &thread_data, const int ld, const int idx, @@ -39,7 +47,7 @@ __device__ inline void LayerNormSmall(T val, const kvp &thread_data, if (threadIdx.x == 0) { mu = sum_kv.key; - rsigma = rsqrt(sum_kv.value - mu * mu + eps); + rsigma = local_rsqrt(sum_kv.value - mu * mu + eps); } __syncthreads(); @@ -63,7 +71,7 @@ __device__ inline void LayerNorm(const kvp &thread_data, const int ld, if (threadIdx.x == 0) { mu = sum_kv.key; - rsigma = rsqrt(sum_kv.value - mu * mu + eps); + rsigma = local_rsqrt(sum_kv.value - mu * mu + eps); } __syncthreads(); @@ -89,7 +97,7 @@ __device__ inline void LayerNorm2(const kvp &thread_data, const int ld, if (threadIdx.x == 0) { mu = sum_kv.key; - rsigma = rsqrt(sum_kv.value - mu * mu + eps); + rsigma = local_rsqrt(sum_kv.value - mu * mu + eps); } __syncthreads(); -- GitLab