Unverified commit 5221832e authored by Taka152, committed by GitHub

Fix wrong idx bug in invertible LayerNormBackward1 (#692)

* fix wrong idx bug in invertible LayerNormBackward1

this index bug causes wrong scale (gamma) grads

* fix unexpected deletion

* fix idx for LayerNormBackward1_fused_add

* move pos definition in LayerNormBackward1 kernels

* fix format error
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Parent 852c524a
@@ -624,9 +624,8 @@ __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
     int offset = threadIdx.y * width + idx;
     int y_stride = width * TILE_DIM;
-    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
-    float betta_reg = (invertible ? (float)betta[pos] : 0.0f);
-    float gamma_reg = (float)gamma[pos];
+    float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
+    float gamma_reg = (float)gamma[idx];
     // Loop across matrix height
     float betta_tmp = 0;
@@ -660,6 +659,7 @@ __global__ void LayerNormBackward1(const T* __restrict__ out_grad,
     }
     if (threadIdx.x == 0) {
+        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
         betta_grad[pos] = s1;
         gamma_grad[pos] = s2;
     }
@@ -1368,9 +1368,8 @@ __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
     int offset = threadIdx.y * width + idx;
     int y_stride = width * TILE_DIM;
-    int pos = blockIdx.x * TILE_DIM + threadIdx.y;
-    float betta_reg = (invertible ? (float)betta[pos] : 0.0f);
-    float gamma_reg = (float)gamma[pos];
+    float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
+    float gamma_reg = (float)gamma[idx];
     // Loop across matrix height
     float betta_tmp = 0;
@@ -1404,6 +1403,7 @@ __global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
     }
     if (threadIdx.x == 0) {
+        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
         betta_grad[pos] = s1;
         gamma_grad[pos] = s2;
     }
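Why the index matters: judging from the surrounding code (offset = threadIdx.y * width + idx), idx is the thread's own column in the [rows x width] activation matrix, so the per-column parameters gamma and betta must be loaded with idx. The transposed index pos = blockIdx.x * TILE_DIM + threadIdx.y only names a column after the tile-wide reduction, when the threadIdx.x == 0 threads each write one reduced value per column; loading betta[pos]/gamma[pos] in the main loop therefore scaled each column by the wrong parameter, corrupting the scale gradient. Below is a minimal standalone sketch of the same tiled-reduction pattern under those assumptions; it is not the DeepSpeed kernel, and the names column_scaled_sum, in, scale, rows, and width are hypothetical.

// Minimal sketch of the idx-vs-pos pattern (illustrative, not DeepSpeed code).
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int TILE_DIM = 8;

__global__ void column_scaled_sum(const float* __restrict__ in,
                                  const float* __restrict__ scale,
                                  float* __restrict__ out,
                                  int rows,
                                  int width)
{
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];  // +1 pad avoids bank conflicts

    // Each thread owns column idx; per-column parameters (gamma/betta in the
    // real kernel) must be read with idx -- this is the bug the commit fixes.
    int idx = blockIdx.x * TILE_DIM + threadIdx.x;
    float scale_reg = (idx < width) ? scale[idx] : 0.0f;

    // Strided loop across the matrix height: partial sum for column idx.
    float sum = 0.0f;
    for (int r = threadIdx.y; r < rows; r += TILE_DIM)
        if (idx < width) sum += in[r * width + idx] * scale_reg;

    tile[threadIdx.y][threadIdx.x] = sum;
    __syncthreads();

    if (threadIdx.x == 0) {
        // Only here does the transposed index make sense: thread (0, y)
        // finalizes column pos = blockIdx.x * TILE_DIM + threadIdx.y, which
        // is generally a different column than its own idx.
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        float total = 0.0f;
        for (int i = 0; i < TILE_DIM; ++i) total += tile[i][threadIdx.y];
        if (pos < width) out[pos] = total;
    }
}

int main()
{
    const int rows = 64, width = 32;
    std::vector<float> h_in(rows * width, 1.0f), h_scale(width, 2.0f), h_out(width);
    float *d_in, *d_scale, *d_out;
    cudaMalloc(&d_in, h_in.size() * sizeof(float));
    cudaMalloc(&d_scale, h_scale.size() * sizeof(float));
    cudaMalloc(&d_out, h_out.size() * sizeof(float));
    cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_scale, h_scale.data(), h_scale.size() * sizeof(float), cudaMemcpyHostToDevice);

    dim3 block(TILE_DIM, TILE_DIM);
    dim3 grid((width + TILE_DIM - 1) / TILE_DIM);
    column_scaled_sum<<<grid, block>>>(d_in, d_scale, d_out, rows, width);
    cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);

    // Every input element is 1.0, scaled by 2.0, summed over 64 rows -> 128.
    printf("out[0] = %.1f (expected 128.0)\n", h_out[0]);
    cudaFree(d_in); cudaFree(d_scale); cudaFree(d_out);
    return 0;
}

Compiled with nvcc and run, the result only comes out right because the load uses idx and the store uses pos; swapping them reproduces exactly the class of bug this commit fixes. Moving the pos definition into the write-out branch, as the third commit bullet does, makes that scoping explicit.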