diff --git a/csrc/transformer/normalize_kernels.cu b/csrc/transformer/normalize_kernels.cu index a0e512c73d44d4e8942468102895d2f5e93094a0..366e9372463830bf92a87c835b7ea919c564247a 100644 --- a/csrc/transformer/normalize_kernels.cu +++ b/csrc/transformer/normalize_kernels.cu @@ -624,9 +624,8 @@ __global__ void LayerNormBackward1(const T* __restrict__ out_grad, int offset = threadIdx.y * width + idx; int y_stride = width * TILE_DIM; - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - float betta_reg = (invertible ? (float)betta[pos] : 0.0f); - float gamma_reg = (float)gamma[pos]; + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; // Loop across matrix height float betta_tmp = 0; @@ -660,6 +659,7 @@ __global__ void LayerNormBackward1(const T* __restrict__ out_grad, } if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; betta_grad[pos] = s1; gamma_grad[pos] = s2; } @@ -1368,9 +1368,8 @@ __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, int offset = threadIdx.y * width + idx; int y_stride = width * TILE_DIM; - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - float betta_reg = (invertible ? (float)betta[pos] : 0.0f); - float gamma_reg = (float)gamma[pos]; + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; // Loop across matrix height float betta_tmp = 0; @@ -1404,6 +1403,7 @@ __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, } if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; betta_grad[pos] = s1; gamma_grad[pos] = s2; }