未验证 提交 c9d26423 编写于 作者: Y Yang Zhang 提交者: GitHub

Fix float64 bug in layer norm (#30454)

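The helper template was itself named `rsqrt`, which shadowed the built-in `rsqrt` that its `double` specialization calls. Rename the helper to `rsqrt_` (and update its callers in the forward and backward kernels) so that the float64 path uses the built-in.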
上级 badc6f22
...@@ -108,23 +108,23 @@ struct PairForLayerNormAddFunctor { ...@@ -108,23 +108,23 @@ struct PairForLayerNormAddFunctor {
}; };
template <typename T> template <typename T>
__inline__ __device__ T rsqrt(const T val) { __inline__ __device__ T rsqrt_(const T val) {
return static_cast<T>(1) / sqrt(val); return static_cast<T>(1) / sqrt(val);
} }
template <> template <>
__inline__ __device__ float rsqrt(const float val) { __inline__ __device__ float rsqrt_(const float val) {
return rsqrtf(val); return rsqrtf(val);
} }
template <> template <>
__inline__ __device__ double rsqrt(const double val) { __inline__ __device__ double rsqrt_(const double val) {
return rsqrt(val); return rsqrt(val);
} }
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <> template <>
__inline__ __device__ half rsqrt(const half val) { __inline__ __device__ half rsqrt_(const half val) {
return hrsqrt(val); return hrsqrt(val);
} }
#endif #endif
...@@ -161,7 +161,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, ...@@ -161,7 +161,7 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
__syncthreads(); __syncthreads();
mean_val = mean_share; mean_val = mean_share;
U invvar = rsqrt<U>(var_share + static_cast<U>(epsilon)); U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
// Step 2: Calculate y // Step 2: Calculate y
if (scale != nullptr) { if (scale != nullptr) {
...@@ -204,7 +204,7 @@ __inline__ __device__ void cuLoadAddStridedInputs( ...@@ -204,7 +204,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(
const int i1 = i1_block + thr_load_row_off; const int i1 = i1_block + thr_load_row_off;
if (i1 >= i1_end) return; if (i1 >= i1_end) return;
U curr_mean = mean[i1]; U curr_mean = mean[i1];
U curr_invvar = rsqrt<U>(var[i1] + epsilon); U curr_invvar = rsqrt_<U>(var[i1] + epsilon);
for (int k = 0; k < VPT; ++k) { for (int k = 0; k < VPT; ++k) {
const int i2 = i2_off + k; const int i2 = i2_off + k;
const int load_idx = i1 * n2 + i2; const int load_idx = i1 * n2 + i2;
...@@ -352,7 +352,7 @@ __global__ void LayerNormBackwardComputeGradInput( ...@@ -352,7 +352,7 @@ __global__ void LayerNormBackwardComputeGradInput(
U sum_loss1 = U(0); U sum_loss1 = U(0);
U sum_loss2 = U(0); U sum_loss2 = U(0);
const U c_mean = mean[i1]; const U c_mean = mean[i1];
const U c_invvar = rsqrt<U>(var[i1] + epsilon); const U c_invvar = rsqrt_<U>(var[i1] + epsilon);
const T *k_input = input + i1 * n2; const T *k_input = input + i1 * n2;
const T *k_dout = dout + i1 * n2; const T *k_dout = dout + i1 * n2;
constexpr int numx = BDIMX * BDIMY; constexpr int numx = BDIMX * BDIMY;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册