From 8256f6fa862e8f46dbd162de8f65939c5f6eeaa9 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 14 Oct 2021 17:53:40 +0800 Subject: [PATCH] fix lars (#36431) --- .../operators/optimizers/lars_momentum_op.cu | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index e90f1136fd3..b640e62221f 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -165,8 +165,10 @@ __global__ void L2NormKernel( int tid = threadIdx.x + blockDim.x * blockIdx.x; int grid_stride = LARS_BLOCK_SIZE * gridDim.x; const MT rescale_pow = rescale_grad * rescale_grad; - s_buffer[0] = static_cast(0); - s_buffer[1] = static_cast(0); + if (threadIdx.x == 0) { + s_buffer[0] = static_cast(0); + s_buffer[1] = static_cast(0); + } MT p_tmp = static_cast(0); MT g_tmp = static_cast(0); @@ -175,8 +177,12 @@ __global__ void L2NormKernel( p_tmp = static_cast(p_data[tid]); g_tmp = static_cast(g_data[tid]); } - s_buffer[0] += math::blockReduceSum(p_tmp * p_tmp, FINAL_MASK); - s_buffer[1] += math::blockReduceSum(g_tmp * g_tmp, FINAL_MASK); + MT tmp0 = math::blockReduceSum(p_tmp * p_tmp, FINAL_MASK); + MT tmp1 = math::blockReduceSum(g_tmp * g_tmp, FINAL_MASK); + if (threadIdx.x == 0) { + s_buffer[0] += tmp0; + s_buffer[1] += tmp1; + } } else { /* Avoid occupy too much temp buffer. Slice the whole data into 2 parts, the front of data whose quantity is excatly multiple of grid-thread @@ -185,8 +191,12 @@ __global__ void L2NormKernel( p_tmp = static_cast(p_data[tid]); g_tmp = static_cast(g_data[tid]); tid += grid_stride; - s_buffer[0] += math::blockReduceSum(p_tmp * p_tmp, FINAL_MASK); - s_buffer[1] += math::blockReduceSum(g_tmp * g_tmp, FINAL_MASK); + MT tmp0 = math::blockReduceSum(p_tmp * p_tmp, FINAL_MASK); + MT tmp1 = math::blockReduceSum(g_tmp * g_tmp, FINAL_MASK); + if (threadIdx.x == 0) { + s_buffer[0] += tmp0; + s_buffer[1] += tmp1; + } __syncthreads(); } MT p_val = 0; @@ -195,8 +205,12 @@ __global__ void L2NormKernel( p_val = static_cast(p_data[tid]); g_val = static_cast(g_data[tid]); } - s_buffer[0] += math::blockReduceSum(p_val * p_val, FINAL_MASK); - s_buffer[1] += math::blockReduceSum(g_val * g_val, FINAL_MASK); + MT tmp0 = math::blockReduceSum(p_val * p_val, FINAL_MASK); + MT tmp1 = math::blockReduceSum(g_val * g_val, FINAL_MASK); + if (threadIdx.x == 0) { + s_buffer[0] += tmp0; + s_buffer[1] += tmp1; + } } __syncthreads(); @@ -208,8 +222,15 @@ __global__ void L2NormKernel( cg->sync(); // Grid sync for writring partial result to gloabl memory MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0; MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0; - *p_n = Sqrt(math::blockReduceSum(p_part_sum, FINAL_MASK)); - *g_n = Sqrt(rescale_pow * math::blockReduceSum(g_part_sum, FINAL_MASK)); + MT tmp0 = math::blockReduceSum(p_part_sum, FINAL_MASK); + MT tmp1 = math::blockReduceSum(g_part_sum, FINAL_MASK); + if (threadIdx.x == 0) { + s_buffer[0] = tmp0; + s_buffer[1] = tmp1; + } + __syncthreads(); + *p_n = Sqrt(s_buffer[0]); + *g_n = Sqrt(rescale_pow * s_buffer[1]); #endif } -- GitLab