未验证 提交 33540e10 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix nan in fast_ln_fwd_kernel when cols > 1024 (#44125)

* Fix nan in fast_ln_fwd_kernel when cols > 1024

* delete blas
上级 9428c969
......@@ -573,7 +573,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
smem[warp_m * WARPS_N + warp_n] = mu_local;
}
__syncthreads();
if (tidx == 0) {
if (tidx % THREADS_PER_ROW == 0) {
mu_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
......@@ -608,7 +608,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
smem[warp_m * WARPS_N + warp_n] = var_local;
}
__syncthreads();
if (tidx == 0) {
if (tidx % THREADS_PER_ROW == 0) {
var_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
......
......@@ -252,7 +252,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
smem[warp_m * WARPS_N + warp_n] = mu_local;
}
__syncthreads();
if (tidx == 0) {
if (tidx % THREADS_PER_ROW == 0) {
mu_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
......@@ -289,7 +289,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
smem[warp_m * WARPS_N + warp_n] = var_local;
}
__syncthreads();
if (tidx == 0) {
if (tidx % THREADS_PER_ROW == 0) {
var_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册