diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 89326679d5d5010a12a279f8fa017ea58c71e9d5..2c27a2135c14b272a8730c3876fdd1af1690c0ea 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -160,7 +160,6 @@ __global__ void L2NormKernel( __shared__ MT s_buffer[2]; int tid = threadIdx.x + blockDim.x * blockIdx.x; int grid_stride = LARS_BLOCK_SIZE * gridDim.x; - const MT rescale_pow = rescale_grad * rescale_grad; MT p_tmp = static_cast(0); MT g_tmp = static_cast(0); @@ -190,7 +189,7 @@ __global__ void L2NormKernel( } __syncthreads(); *p_n = Sqrt(s_buffer[0]); - *g_n = Sqrt(rescale_pow * s_buffer[1]); + *g_n = rescale_grad * Sqrt(s_buffer[1]); #endif }