Patch summary (paddle/fluid/operators/math/bert_encoder_functor.cu)

  * Adds the device helper local_rsqrt():
      - generic template: casts its argument to float, calls rsqrt(), and
        returns the result converted back to T;
      - half overload: compiled only when
        CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) is true, calls hrsqrt()
        directly on the half value.
  * LayerNormSmall, LayerNorm and LayerNorm2 now compute rsigma via
    local_rsqrt(sum_kv.value - mu * mu + eps) instead of calling rsqrt()
    directly; nothing else in those functions is touched by this patch.

NOTE(review): this patch text was recovered from a copy that had lost its
line breaks and every token enclosed in angle brackets. The template
parameter lists, the static_cast target type, the kvp template argument and
the exact whitespace of context lines were reconstructed -- confirm against
the upstream file before applying.

diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu
index 4d7218cd89e04b5122ff4385abfb2c7305e40c0a..645d1f637183c7bcec297a7f2ce3cd73a01e53c3 100644
--- a/paddle/fluid/operators/math/bert_encoder_functor.cu
+++ b/paddle/fluid/operators/math/bert_encoder_functor.cu
@@ -25,6 +25,14 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+template <typename T>
+__device__ __forceinline__ T local_rsqrt(T num) {
+  return rsqrt(static_cast<float>(num));
+}
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+__device__ __forceinline__ half local_rsqrt(half num) { return hrsqrt(num); }
+#endif
+
 template <typename T, int TPB>
 __device__ inline void LayerNormSmall(T val, const kvp<T> &thread_data,
                                       const int ld, const int idx,
@@ -39,7 +47,7 @@ __device__ inline void LayerNormSmall(T val, const kvp<T> &thread_data,
 
   if (threadIdx.x == 0) {
     mu = sum_kv.key;
-    rsigma = rsqrt(sum_kv.value - mu * mu + eps);
+    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
   }
   __syncthreads();
 
@@ -63,7 +71,7 @@ __device__ inline void LayerNorm(const kvp<T> &thread_data, const int ld,
 
   if (threadIdx.x == 0) {
     mu = sum_kv.key;
-    rsigma = rsqrt(sum_kv.value - mu * mu + eps);
+    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
   }
   __syncthreads();
 
@@ -89,7 +97,7 @@ __device__ inline void LayerNorm2(const kvp<T> &thread_data, const int ld,
 
   if (threadIdx.x == 0) {
     mu = sum_kv.key;
-    rsigma = rsqrt(sum_kv.value - mu * mu + eps);
+    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
   }
   __syncthreads();
 