diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu
index b640e62221f77775789496eea17d55a5cefcc90e..89326679d5d5010a12a279f8fa017ea58c71e9d5 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -84,22 +84,18 @@ class LarsThreadConfig {
 
 template <typename T, typename MT, int VecSize, bool IsAmp = false>
 __device__ inline void VectorizeLarsUpdate(
-    const T* __restrict__ grad, const MT* __restrict__ param,
-    const MT* __restrict__ velocity, T* __restrict__ param_out,
-    MT* __restrict__ velocity_out, const MT mu, MT local_lr,
+    const T* __restrict__ grad, const MT* param, const MT* velocity,
+    T* param_out, MT* velocity_out, const MT mu, MT local_lr,
     const MT lars_weight_decay, const MT rescale_grad, const int tid,
-    const int grid_stride, const int numel,
-    MT* __restrict__ master_param_out = nullptr) {
+    const int grid_stride, const int numel, MT* master_param_out = nullptr) {
   using VecType = paddle::platform::AlignedVector<T, VecSize>;
   using VecMType = paddle::platform::AlignedVector<MT, VecSize>;
   int main = numel >> (VecSize >> 1);
   int tail_offset = main * VecSize;
 
-  const VecType* __restrict__ grad_vec = reinterpret_cast<const VecType*>(grad);
-  const VecMType* __restrict__ param_vec =
-      reinterpret_cast<const VecMType*>(param);
-  const VecMType* __restrict__ velocity_vec =
-      reinterpret_cast<const VecMType*>(velocity);
+  const VecType* grad_vec = reinterpret_cast<const VecType*>(grad);
+  const VecMType* param_vec = reinterpret_cast<const VecMType*>(param);
+  const VecMType* velocity_vec = reinterpret_cast<const VecMType*>(velocity);
   VecType* param_out_vec = reinterpret_cast<VecType*>(param_out);
   VecMType* velocity_out_vec = reinterpret_cast<VecMType*>(velocity_out);
 
@@ -157,66 +153,30 @@ __forceinline__ __device__ void L2NormKernel(
 template <typename T, typename MT>
 __global__ void L2NormKernel(
 #endif
-    const T* __restrict__ p_data, const T* __restrict__ g_data,
-    MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel,
-    const int repeat_times, const MT rescale_grad, const int thresh = 0,
-    MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) {
+    const T* p_data, const T* __restrict__ g_data, MT* __restrict__ p_buffer,
+    MT* __restrict__ g_buffer, const int64_t numel, const int repeat_times,
+    const MT rescale_grad, const int thresh = 0, MT* __restrict__ p_n = nullptr,
+    MT* __restrict__ g_n = nullptr) {
   __shared__ MT s_buffer[2];
   int tid = threadIdx.x + blockDim.x * blockIdx.x;
   int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
   const MT rescale_pow = rescale_grad * rescale_grad;
-  if (threadIdx.x == 0) {
-    s_buffer[0] = static_cast<MT>(0);
-    s_buffer[1] = static_cast<MT>(0);
-  }
+
   MT p_tmp = static_cast<MT>(0);
   MT g_tmp = static_cast<MT>(0);
-
-  if (repeat_times == 0) {
-    if (tid < numel) {
-      p_tmp = static_cast<MT>(p_data[tid]);
-      g_tmp = static_cast<MT>(g_data[tid]);
-    }
-    MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
-    MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
-    if (threadIdx.x == 0) {
-      s_buffer[0] += tmp0;
-      s_buffer[1] += tmp1;
-    }
-  } else {
-    /* Avoid occupy too much temp buffer. Slice the whole data into 2 parts,
-    the front of data whose quantity is excatly multiple of grid-thread
-    number, and delt in for loop, the rest is delt with another step. */
-    for (int i = 0; i < repeat_times; ++i) {
-      p_tmp = static_cast<MT>(p_data[tid]);
-      g_tmp = static_cast<MT>(g_data[tid]);
-      tid += grid_stride;
-      MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
-      MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
-      if (threadIdx.x == 0) {
-        s_buffer[0] += tmp0;
-        s_buffer[1] += tmp1;
-      }
-      __syncthreads();
-    }
-    MT p_val = 0;
-    MT g_val = 0;
-    if (tid < numel) {
-      p_val = static_cast<MT>(p_data[tid]);
-      g_val = static_cast<MT>(g_data[tid]);
-    }
-    MT tmp0 = math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
-    MT tmp1 = math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
-    if (threadIdx.x == 0) {
-      s_buffer[0] += tmp0;
-      s_buffer[1] += tmp1;
-    }
+  while (tid < numel) {
+    MT tmp0 = static_cast<MT>(p_data[tid]);
+    MT tmp1 = static_cast<MT>(g_data[tid]);
+    p_tmp += (tmp0 * tmp0);
+    g_tmp += (tmp1 * tmp1);
+    tid += grid_stride;
   }
-  __syncthreads();
+  p_tmp = math::blockReduceSum<MT>(p_tmp, FINAL_MASK);
+  g_tmp = math::blockReduceSum<MT>(g_tmp, FINAL_MASK);
   if (threadIdx.x == 0) {
-    p_buffer[blockIdx.x] = s_buffer[0];
-    g_buffer[blockIdx.x] = s_buffer[1];
+    p_buffer[blockIdx.x] = p_tmp;
+    g_buffer[blockIdx.x] = g_tmp;
   }
 #if CUDA_VERSION >= 11000
   cg->sync();  // Grid sync for writring partial result to gloabl memory
@@ -236,10 +196,9 @@ __global__ void L2NormKernel(
 
 template <typename T, typename MT>
 __forceinline__ __device__ void MomentumUpdate(
-    const T* __restrict__ param, const T* __restrict__ grad,
-    const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
-    const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
-    const MT* __restrict__ learning_rate, const MT mu,
+    const T* param, const T* __restrict__ grad, const MT* velocity,
+    T* param_out, MT* velocity_out, const MT* master_param,
+    MT* master_param_out, const MT* __restrict__ learning_rate, const MT mu,
     const MT lars_weight_decay, const MT lars_coeff, const MT epsilon,
     const MT rescale_grad, const MT param_norm, const MT grad_norm,
     const int tid, const int grid_stride, const int64_t numel,
@@ -316,14 +275,13 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT> lars_warpper,
 
 template <typename T, typename MT>
 __global__ void MomentumLarsKernel(
-    const T* __restrict__ param, const T* __restrict__ grad,
-    const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
-    const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
-    const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer,
-    MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff,
-    const MT lars_weight_decay, const MT epsilon, const MT rescale_grad,
-    const int repeat_times, const int thresh, const int64_t numel,
-    const bool is_amp) {
+    const T* param, const T* __restrict__ grad, const MT* velocity,
+    T* param_out, MT* velocity_out, const MT* master_param,
+    MT* master_param_out, const MT* __restrict__ learning_rate,
+    MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const MT mu,
+    const MT lars_coeff, const MT lars_weight_decay, const MT epsilon,
+    const MT rescale_grad, const int repeat_times, const int thresh,
+    const int64_t numel, const bool is_amp) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
 #if CUDA_VERSION >= 11000
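The core of this patch is the `L2NormKernel` hunk: the old `repeat_times` slicing, with a `blockReduceSum` call and a shared-memory `s_buffer` accumulation per iteration, is collapsed into a single grid-stride loop that accumulates squared values in registers, followed by exactly one block-wide reduction and one per-block write. The standalone sketch below illustrates that shape; it is not Paddle code, `BlockReduceSum` and `SquaredSumKernel` are hypothetical stand-ins for `math::blockReduceSum` and the patched kernel, and it reduces one `float` buffer rather than the paired param/grad buffers.

```cuda
// Minimal sketch, NOT Paddle code: the reduction shape used by the rewritten
// L2NormKernel. Assumes blockDim.x is a multiple of 32, as with the patch's
// LARS_BLOCK_SIZE launches.
#include <cstdint>

#define FULL_MASK 0xffffffffu

// Warp-shuffle tree reduction across one thread block, analogous in shape to
// the patch's math::blockReduceSum<MT>(val, FINAL_MASK).
__device__ float BlockReduceSum(float val) {
  __shared__ float warp_partials[32];  // one slot per warp
  int lane = threadIdx.x & 31;
  int wid = threadIdx.x >> 5;
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(FULL_MASK, val, offset);
  }
  if (lane == 0) warp_partials[wid] = val;  // lane 0 holds each warp's sum
  __syncthreads();
  // The first warp reduces the per-warp partial sums.
  val = (threadIdx.x < (blockDim.x >> 5)) ? warp_partials[lane] : 0.f;
  if (wid == 0) {
    for (int offset = 16; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(FULL_MASK, val, offset);
    }
  }
  return val;
}

// One grid-stride loop accumulating squares in a register, then a single
// block reduction: this is what replaces the old repeat_times bookkeeping,
// per-iteration blockReduceSum calls, and s_buffer accumulator.
__global__ void SquaredSumKernel(const float* __restrict__ data,
                                 float* __restrict__ block_sums,
                                 int64_t numel) {
  float acc = 0.f;
  for (int64_t i = threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
       i < numel; i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    acc += data[i] * data[i];
  }
  acc = BlockReduceSum(acc);
  if (threadIdx.x == 0) block_sums[blockIdx.x] = acc;  // per-block partial
}
```

As in the patch, `block_sums` still holds one partial per block; the final norm is produced by a second reduction pass, which the patched kernel reaches after the cooperative-groups `cg->sync()` on the `CUDA_VERSION >= 11000` path.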