未验证 提交 33540e10 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix NaN in fast_ln_fwd_kernel when cols > 1024 (#44125)

* Fix NaN in fast_ln_fwd_kernel when cols > 1024

* delete blas
上级 9428c969
@@ -573,7 +573,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
 smem[warp_m * WARPS_N + warp_n] = mu_local;
 }
 __syncthreads();
-if (tidx == 0) {
+if (tidx % THREADS_PER_ROW == 0) {
 mu_local = 0.f;
 #pragma unroll
 for (int it = 0; it < WARPS_N; ++it) {
@@ -608,7 +608,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
 smem[warp_m * WARPS_N + warp_n] = var_local;
 }
 __syncthreads();
-if (tidx == 0) {
+if (tidx % THREADS_PER_ROW == 0) {
 var_local = 0.f;
 #pragma unroll
 for (int it = 0; it < WARPS_N; ++it) {
......
@@ -252,7 +252,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
 smem[warp_m * WARPS_N + warp_n] = mu_local;
 }
 __syncthreads();
-if (tidx == 0) {
+if (tidx % THREADS_PER_ROW == 0) {
 mu_local = 0.f;
 #pragma unroll
 for (int it = 0; it < WARPS_N; ++it) {
@@ -289,7 +289,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
 smem[warp_m * WARPS_N + warp_n] = var_local;
 }
 __syncthreads();
-if (tidx == 0) {
+if (tidx % THREADS_PER_ROW == 0) {
 var_local = 0.f;
 #pragma unroll
 for (int it = 0; it < WARPS_N; ++it) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册