Unverified commit addd5fce, authored by W wenbin, committed by GitHub

miss format (#34771)

Parent 4d2994cb
@@ -25,6 +25,14 @@ namespace paddle {
 namespace operators {
 namespace math {
+template <typename T>
+__device__ __forceinline__ T local_rsqrt(T num) {
+  return rsqrt(static_cast<float>(num));
+}
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+__device__ __forceinline__ half local_rsqrt(half num) { return hrsqrt(num); }
+#endif
 template <typename T, int TPB>
 __device__ inline void LayerNormSmall(T val, const kvp<T> &thread_data,
                                       const int ld, const int idx,
@@ -39,7 +47,7 @@ __device__ inline void LayerNormSmall(T val, const kvp<T> &thread_data,
   if (threadIdx.x == 0) {
     mu = sum_kv.key;
-    rsigma = rsqrt(sum_kv.value - mu * mu + eps);
+    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
   }
   __syncthreads();
@@ -63,7 +71,7 @@ __device__ inline void LayerNorm(const kvp<T> &thread_data, const int ld,
   if (threadIdx.x == 0) {
     mu = sum_kv.key;
-    rsigma = rsqrt(sum_kv.value - mu * mu + eps);
+    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
   }
   __syncthreads();
@@ -89,7 +97,7 @@ __device__ inline void LayerNorm2(const kvp<T> &thread_data, const int ld,
   if (threadIdx.x == 0) {
     mu = sum_kv.key;
-    rsigma = rsqrt(sum_kv.value - mu * mu + eps);
+    rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
   }
   __syncthreads();
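The point of the new helper is overload dispatch: the generic local_rsqrt promotes its argument to float and calls rsqrt, while the half overload calls the native hrsqrt intrinsic so fp16 values never round-trip through float. Below is a minimal, self-contained sketch of that pattern, not the Paddle source itself. The CUDA_ARCH_FP16_SUPPORTED macro is replaced here by an explicit __CUDA_ARCH__ >= 530 guard (an assumption about what that macro checks), and rsigma_kernel is a hypothetical kernel that mirrors only the rsigma = local_rsqrt(...) step from the diff.

    #include <cstdio>
    #include <cuda_fp16.h>

    // Generic fallback: promote to float and use the float rsqrt.
    template <typename T>
    __device__ __forceinline__ T local_rsqrt(T num) {
      return rsqrt(static_cast<float>(num));
    }

    // On architectures with native fp16 math (assumed sm_53+), use hrsqrt
    // directly so the value stays in half precision.
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
    __device__ __forceinline__ half local_rsqrt(half num) { return hrsqrt(num); }
    #endif

    // Hypothetical kernel computing 1/sqrt(variance + eps), the step the
    // LayerNorm kernels in this diff perform when threadIdx.x == 0.
    template <typename T>
    __global__ void rsigma_kernel(const T *var, T eps, T *out) {
      out[0] = local_rsqrt(var[0] + eps);
    }

    int main() {
      float h_var = 4.0f, h_out = 0.0f;
      float *d_var = nullptr, *d_out = nullptr;
      cudaMalloc(&d_var, sizeof(float));
      cudaMalloc(&d_out, sizeof(float));
      cudaMemcpy(d_var, &h_var, sizeof(float), cudaMemcpyHostToDevice);
      rsigma_kernel<float><<<1, 1>>>(d_var, 1e-5f, d_out);
      cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
      printf("rsigma = %f\n", h_out);  // ~0.5 for var = 4
      cudaFree(d_var);
      cudaFree(d_out);
      return 0;
    }

Because overload resolution prefers the exact half overload when it is visible, call sites such as the three LayerNorm kernels do not change beyond swapping rsqrt for local_rsqrt; float and double instantiations keep going through the generic template.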