未验证 提交 4e036fa1 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine rescale_grad (#36490)

上级 314cc495
...@@ -160,7 +160,6 @@ __global__ void L2NormKernel( ...@@ -160,7 +160,6 @@ __global__ void L2NormKernel(
__shared__ MT s_buffer[2]; __shared__ MT s_buffer[2];
int tid = threadIdx.x + blockDim.x * blockIdx.x; int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x; int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
const MT rescale_pow = rescale_grad * rescale_grad;
MT p_tmp = static_cast<MT>(0); MT p_tmp = static_cast<MT>(0);
MT g_tmp = static_cast<MT>(0); MT g_tmp = static_cast<MT>(0);
...@@ -190,7 +189,7 @@ __global__ void L2NormKernel( ...@@ -190,7 +189,7 @@ __global__ void L2NormKernel(
} }
__syncthreads(); __syncthreads();
*p_n = Sqrt(s_buffer[0]); *p_n = Sqrt(s_buffer[0]);
*g_n = Sqrt(rescale_pow * s_buffer[1]); *g_n = rescale_grad * Sqrt(s_buffer[1]);
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册