未验证 提交 4e036fa1 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine rescale_grad (#36490)

上级 314cc495
......@@ -160,7 +160,6 @@ __global__ void L2NormKernel(
__shared__ MT s_buffer[2];
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
const MT rescale_pow = rescale_grad * rescale_grad;
MT p_tmp = static_cast<MT>(0);
MT g_tmp = static_cast<MT>(0);
......@@ -190,7 +189,7 @@ __global__ void L2NormKernel(
}
__syncthreads();
*p_n = Sqrt(s_buffer[0]);
*g_n = Sqrt(rescale_pow * s_buffer[1]);
*g_n = rescale_grad * Sqrt(s_buffer[1]);
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册