From c9d26423b3ec9f390d3ed2c7f58fd4b826f3b4b3 Mon Sep 17 00:00:00 2001
From: Yang Zhang
Date: Fri, 15 Jan 2021 18:53:59 +0800
Subject: [PATCH] Fix float64 bug in layer norm (#30454)

built-in `rsqrt` is shadowed
---
 paddle/fluid/operators/layer_norm_op.cu | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu
index ad15b18d7fe..6883ba009c5 100644
--- a/paddle/fluid/operators/layer_norm_op.cu
+++ b/paddle/fluid/operators/layer_norm_op.cu
@@ -108,23 +108,23 @@ struct PairForLayerNormAddFunctor {
 };
 
 template <typename T>
-__inline__ __device__ T rsqrt(const T val) {
+__inline__ __device__ T rsqrt_(const T val) {
   return static_cast<T>(1) / sqrt(val);
 }
 
 template <>
-__inline__ __device__ float rsqrt(const float val) {
+__inline__ __device__ float rsqrt_(const float val) {
   return rsqrtf(val);
 }
 
 template <>
-__inline__ __device__ double rsqrt(const double val) {
+__inline__ __device__ double rsqrt_(const double val) {
   return rsqrt(val);
 }
 
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
 template <>
-__inline__ __device__ half rsqrt(const half val) {
+__inline__ __device__ half rsqrt_(const half val) {
   return hrsqrt(val);
 }
 #endif
@@ -161,7 +161,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
   __syncthreads();
 
   mean_val = mean_share;
-  U invvar = rsqrt<U>(var_share + static_cast<U>(epsilon));
+  U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
 
   // Step 2: Calculate y
   if (scale != nullptr) {
@@ -204,7 +204,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(
   const int i1 = i1_block + thr_load_row_off;
   if (i1 >= i1_end) return;
   U curr_mean = mean[i1];
-  U curr_invvar = rsqrt<U>(var[i1] + epsilon);
+  U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
   for (int k = 0; k < VPT; ++k) {
     const int i2 = i2_off + k;
     const int load_idx = i1 * n2 + i2;
@@ -352,7 +352,7 @@ __global__ void LayerNormBackwardComputeGradInput(
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     const U c_mean = mean[i1];
-    const U c_invvar = rsqrt<U>(var[i1] + epsilon);
+    const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
     const T *k_input = input + i1 * n2;
     const T *k_dout = dout + i1 * n2;
     constexpr int numx = BDIMX * BDIMY;
-- 
GitLab