diff --git a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h
index e2d908b853188f51136b58dffe22219c924973d0..778f13634300b1e86326f191802c1590255c7499 100644
--- a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h
+++ b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h
@@ -268,10 +268,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
       for (int it = 0; it < WARPS_N; ++it) {
         mu_local += smem[warp_m * WARPS_N + it];
       }
-      smem[warp_m] = mu_local;
+      smem[warp_m * WARPS_N] = mu_local;
     }
     __syncthreads();
-    mu_local = smem[warp_m];
+    mu_local = smem[warp_m * WARPS_N];
   }
 
   mu_local *= rn;
@@ -295,6 +295,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
   }
 
   if (WARPS_N > 1) {
+    __syncthreads();
     if (lane == 0) {
       smem[warp_m * WARPS_N + warp_n] = var_local;
     }
@@ -305,10 +306,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
       for (int it = 0; it < WARPS_N; ++it) {
         var_local += smem[warp_m * WARPS_N + it];
       }
-      smem[warp_m] = var_local;
+      smem[warp_m * WARPS_N] = var_local;
     }
     __syncthreads();
-    var_local = smem[warp_m];
+    var_local = smem[warp_m * WARPS_N];
   }
 
   // Note: to assure if it is right for double
diff --git a/paddle/phi/kernels/gpu/layer_norm_kernel.cu b/paddle/phi/kernels/gpu/layer_norm_kernel.cu
index 9ba6dfd4235358ff091abcfbfb6ca4f6b3875e8a..c5bb0c288f260968c7f163ccbc8db63c2ef45ba8 100644
--- a/paddle/phi/kernels/gpu/layer_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/layer_norm_kernel.cu
@@ -606,7 +606,7 @@ void LayerNormKernel(const Context &dev_ctx,
   if ((feature_size >= 768 && feature_size <= 2048 &&
        feature_size % 256 == 0 ||
        feature_size == 4096) &&
      scale != nullptr && bias != nullptr) {
-    can_call_fast_kernel = false;
+    can_call_fast_kernel = true;
   }
   if (can_call_fast_kernel) {
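
To make the two `fast_ln_fwd_kernel` fixes concrete, here is a minimal standalone sketch of the same WARPS_M x WARPS_N shared-memory reduction pattern. This is not Paddle's kernel: the kernel name `row_mean_var_sketch`, the constant values, and the launch shape are assumptions for illustration. It shows why the per-row total must be stored at `smem[warp_m * WARPS_N]` rather than `smem[warp_m]`, and why an extra `__syncthreads()` is required before the variance pass reuses the buffer.

```cuda
// A minimal sketch (not Paddle's kernel) of the cross-warp reduction the
// patch fixes. Each CTA handles WARPS_M rows; WARPS_N warps cooperate on
// each row, staging partials in one smem slot per (row, warp) pair.
#include <cuda_runtime.h>

constexpr int THREADS_PER_WARP = 32;
constexpr int WARPS_M = 4;  // rows processed per CTA (assumed value)
constexpr int WARPS_N = 4;  // warps cooperating on one row (assumed value)

__global__ void row_mean_var_sketch(const float *x, float *mean, float *var,
                                    int cols) {
  // One partial slot per (row, warp) pair: smem[warp_m * WARPS_N + warp_n].
  __shared__ float smem[WARPS_M * WARPS_N];

  const int lane = threadIdx.x % THREADS_PER_WARP;
  const int warp = threadIdx.x / THREADS_PER_WARP;
  const int warp_m = warp / WARPS_N;  // which row this warp works on
  const int warp_n = warp % WARPS_N;  // this warp's slice of the row
  const int row = blockIdx.x * WARPS_M + warp_m;
  const int col0 = warp_n * THREADS_PER_WARP + lane;

  // ---- pass 1: row mean ----
  float mu = 0.f;
  for (int c = col0; c < cols; c += WARPS_N * THREADS_PER_WARP) {
    mu += x[row * cols + c];
  }
  for (int off = THREADS_PER_WARP / 2; off > 0; off /= 2) {
    mu += __shfl_xor_sync(0xffffffffu, mu, off);  // intra-warp reduction
  }
  if (lane == 0) smem[warp_m * WARPS_N + warp_n] = mu;
  __syncthreads();
  if (warp_n == 0 && lane == 0) {
    mu = 0.f;
    for (int it = 0; it < WARPS_N; ++it) {
      mu += smem[warp_m * WARPS_N + it];
    }
    // The indexing fix: the row total must land in this row's own stripe
    // (slot 0), which this same thread has already consumed. Storing it at
    // smem[warp_m] instead aliases another row's partial slot -- e.g. row 1's
    // total would overwrite smem[1], which row 0's reducer may still be
    // reading as its warp_n == 1 partial.
    smem[warp_m * WARPS_N] = mu;
  }
  __syncthreads();
  mu = smem[warp_m * WARPS_N] / cols;  // broadcast the row mean

  // ---- pass 2: row variance, reusing the same smem buffer ----
  float v = 0.f;
  for (int c = col0; c < cols; c += WARPS_N * THREADS_PER_WARP) {
    float d = x[row * cols + c] - mu;
    v += d * d;
  }
  for (int off = THREADS_PER_WARP / 2; off > 0; off /= 2) {
    v += __shfl_xor_sync(0xffffffffu, v, off);
  }
  // The added __syncthreads(): without it, a fast warp could overwrite a
  // mean total below while a slower warp is still reading it above.
  __syncthreads();
  if (lane == 0) smem[warp_m * WARPS_N + warp_n] = v;
  __syncthreads();
  if (warp_n == 0 && lane == 0) {
    v = 0.f;
    for (int it = 0; it < WARPS_N; ++it) {
      v += smem[warp_m * WARPS_N + it];
    }
    mean[row] = mu;
    var[row] = v / cols;
  }
}
```

Launch the sketch with `rows / WARPS_M` blocks of `WARPS_M * WARPS_N * 32` threads (rows assumed divisible by WARPS_M); the point is the smem indexing and the extra barrier, not performance. Once both races are gone, the dispatch change in layer_norm_kernel.cu can safely re-enable the fast path for the covered feature sizes.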