未验证 提交 1f987a75 编写于 作者: J Jeng Bai-Cheng 提交者: GitHub

bugfix, read and write race at fast_ln_fwd (#56435)

上级 abf05e11
...@@ -268,10 +268,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel( ...@@ -268,10 +268,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
for (int it = 0; it < WARPS_N; ++it) { for (int it = 0; it < WARPS_N; ++it) {
mu_local += smem[warp_m * WARPS_N + it]; mu_local += smem[warp_m * WARPS_N + it];
} }
smem[warp_m] = mu_local; smem[warp_m * WARPS_N] = mu_local;
} }
__syncthreads(); __syncthreads();
mu_local = smem[warp_m]; mu_local = smem[warp_m * WARPS_N];
} }
mu_local *= rn; mu_local *= rn;
...@@ -295,6 +295,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel( ...@@ -295,6 +295,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
} }
if (WARPS_N > 1) { if (WARPS_N > 1) {
__syncthreads();
if (lane == 0) { if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = var_local; smem[warp_m * WARPS_N + warp_n] = var_local;
} }
...@@ -305,10 +306,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel( ...@@ -305,10 +306,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
for (int it = 0; it < WARPS_N; ++it) { for (int it = 0; it < WARPS_N; ++it) {
var_local += smem[warp_m * WARPS_N + it]; var_local += smem[warp_m * WARPS_N + it];
} }
smem[warp_m] = var_local; smem[warp_m * WARPS_N] = var_local;
} }
__syncthreads(); __syncthreads();
var_local = smem[warp_m]; var_local = smem[warp_m * WARPS_N];
} }
// Note: to assure if it is right for double // Note: to assure if it is right for double
......
...@@ -606,7 +606,7 @@ void LayerNormKernel(const Context &dev_ctx, ...@@ -606,7 +606,7 @@ void LayerNormKernel(const Context &dev_ctx,
if ((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0 || if ((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0 ||
feature_size == 4096) && feature_size == 4096) &&
scale != nullptr && bias != nullptr) { scale != nullptr && bias != nullptr) {
can_call_fast_kernel = false; can_call_fast_kernel = true;
} }
if (can_call_fast_kernel) { if (can_call_fast_kernel) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册