未验证 提交 1f987a75 编写于 作者: J Jeng Bai-Cheng 提交者: GitHub

bugfix, read and write race at fast_ln_fwd (#56435)

上级 abf05e11
......@@ -268,10 +268,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
for (int it = 0; it < WARPS_N; ++it) {
mu_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = mu_local;
smem[warp_m * WARPS_N] = mu_local;
}
__syncthreads();
mu_local = smem[warp_m];
mu_local = smem[warp_m * WARPS_N];
}
mu_local *= rn;
......@@ -295,6 +295,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
}
if (WARPS_N > 1) {
__syncthreads();
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = var_local;
}
......@@ -305,10 +306,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
for (int it = 0; it < WARPS_N; ++it) {
var_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = var_local;
smem[warp_m * WARPS_N] = var_local;
}
__syncthreads();
var_local = smem[warp_m];
var_local = smem[warp_m * WARPS_N];
}
// Note: to assure if it is right for double
......
......@@ -606,7 +606,7 @@ void LayerNormKernel(const Context &dev_ctx,
if ((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0 ||
feature_size == 4096) &&
scale != nullptr && bias != nullptr) {
can_call_fast_kernel = false;
can_call_fast_kernel = true;
}
if (can_call_fast_kernel) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册