From d17961edc0f32f640861db93ed2e8660062ba2b7 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 1 Mar 2022 09:55:33 +0800 Subject: [PATCH] Optimize the CUDA kernel in DistributedFusedLamb optimizer (#39972) * vectorize lamb kernel * remove flags, add ut * remove useless codes * refine code, add param order --- .../distributed_fused_lamb_init_op.cc | 39 +- .../distributed_fused_lamb_init_op.cu | 162 ++--- .../optimizers/distributed_fused_lamb_op.cc | 34 +- .../optimizers/distributed_fused_lamb_op.cu | 682 ++++++++++-------- .../operators/optimizers/multi_tensor_apply.h | 61 +- .../distributed_fused_lamb_test_base.py | 5 + .../optimizer/distributed_fused_lamb.py | 21 +- 7 files changed, 546 insertions(+), 458 deletions(-) diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc index 28c6efef141..efec50efa92 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cc @@ -61,30 +61,31 @@ class DistributedFusedLambInitOpMaker "The fp32 beta1 power accumulator tensor. Its shape is [1]."); AddOutput("Beta2Pow", "The fp32 beta2 power accumulator tensor. Its shape is [1]."); - AddOutput("FusedIndices", - "The param index of each element in FP32FusedParam. Its shape is " - "[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...]."); AddOutput( "FusedParamOffsets", "The numel offset of each parameter inside the FP32FusedParam. Its " "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 " - "+ n_2, ...]."); - AddOutput("FP32ShardFusedParamOffsets", - "The sharded numel offset of each parameter in the local rank. " - "Its shape is [fp32_local_param_num + 1]."); - AddOutput("FP16ShardFusedParamOffsets", - "The sharded numel offset of each parameter in the local rank. " - "Its shape is [fp16_local_param_num + 1]."); + "+ n_2, ...]. It should be in CPUPlace."); AddOutput( - "WeightDecay", - "The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N]."); + "FP32ShardFusedParamOffsets", + "The sharded numel offset of each parameter in the local rank. " + "Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace."); + AddOutput( + "FP16ShardFusedParamOffsets", + "The sharded numel offset of each parameter in the local rank. " + "Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace."); AddOutput("ParamInfo", "The param info. It should be in CPUPlace, and its shape is [6]" - "CPUPlace, and its shape is [6]. It is " + "CPUPlace, and its shape is [8]. It is " "[fp32_shard_param_start_idx, fp32_local_param_num, " - "fp32_global_param_num, fp16_shard_param_start_idx, " - "fp16_local_param_num, fp16_global_param_num]."); - + "fp32_global_param_num, fp32_weight_decay_end_idx, " + "fp16_shard_param_start_idx, " + "fp16_local_param_num, fp16_global_param_num, " + "fp16_weight_decay_end_idx]."); + AddOutput("ParamOrder", + "The reordered parameter order. Inside this op, " + "the parameter would be reordered by data type and weight decay " + "value."); AddOutput("ParamOut", "The output parameter list.").AsDuplicable(); AddOutput("MasterParamOut", "The output master parameter list. It would share the memory of " @@ -96,10 +97,8 @@ class DistributedFusedLambInitOpMaker AddAttr("beta1", "The initial value of Beta1Pow."); AddAttr("beta2", "The initial value of Beta2Pow."); - AddAttr>( - "weight_decay", - "The weight decay for each parameter. Its " - "shape is equal to the global parameter number."); + AddAttr>("apply_weight_decay", + "Whether to apply weight decay."); AddAttr("alignment", "The alignment in bytes for the fused tensors."); AddAttr("rank", "The global rank of the current process."); AddAttr("nranks", "The global world size."); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu index 3445e9b658b..7d8a7186d58 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu @@ -258,32 +258,6 @@ static void ShareBufferForNonInitedTensor(framework::Tensor *origin, << ") , dtype = " << fused_out->dtype(); } -template -static __global__ void LambFillFusedIndicesCUDAKernel(const OffsetT *offsets, - IndexT *out, - int offset_num, - int out_num) { - CUDA_KERNEL_LOOP_TYPE(i, out_num, int) { - auto idx = phi::funcs::LowerBound(offsets, offset_num, i); - if (idx == offset_num || offsets[idx] != i) { - --idx; - } - out[i] = idx; - } -} - -template -static void CopyVectorToTensor(const std::vector &src, - framework::Tensor *dst, - const platform::Place &place, - gpuStream_t stream) { - dst->Resize({static_cast(src.size())}); - T *dst_ptr = dst->mutable_data(place); - const T *src_ptr = src.data(); - auto nbytes = src.size() * sizeof(T); - memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream); -} - template static void CopyVectorToCPUTensor(const std::vector &src, framework::Tensor *dst) { @@ -294,6 +268,42 @@ static void CopyVectorToCPUTensor(const std::vector &src, std::memcpy(dst_ptr, src_ptr, nbytes); } +static size_t ReorderParamGradInfoList(const std::vector &flags, + std::vector *infos) { + size_t n = infos->size(); + std::vector cur_flags; + cur_flags.reserve(n); + for (size_t i = 0; i < n; ++i) { + auto idx = (*infos)[i].idx; + cur_flags.push_back(flags[idx]); + } + + auto origin_infos = *infos; + size_t j = 0; + for (size_t i = 0; i < n; ++i) { + if (cur_flags[i]) { + (*infos)[j] = origin_infos[i]; + ++j; + } + } + size_t ret_idx = j; + + for (size_t i = 0; i < n; ++i) { + if (!cur_flags[i]) { + (*infos)[j] = origin_infos[i]; + ++j; + } + } + return ret_idx; +} + +template +static T ClipByBound(T x, T low_value, T high_value) { + if (x < low_value) return low_value; + if (x > high_value) return high_value; + return x; +} + template class DistributedFusedLambInitOpKernel : public framework::OpKernel { @@ -404,6 +414,24 @@ class DistributedFusedLambInitOpKernel info->numel_offset = 0; // not determined yet } } + const auto &apply_weight_decay = + ctx.Attr>("apply_weight_decay"); + size_t fp32_wd_end_idx = + ReorderParamGradInfoList(apply_weight_decay, &fp32_infos); + size_t fp16_wd_end_idx = + ReorderParamGradInfoList(apply_weight_decay, &fp16_infos); + + auto *param_order_t = ctx.Output("ParamOrder"); + auto param_num = fp32_infos.size() + fp16_infos.size(); + param_order_t->Resize({static_cast(param_num)}); + auto *param_order = param_order_t->mutable_data(platform::CPUPlace()); + for (size_t i = 0; i < fp32_infos.size(); ++i) { + param_order[i] = static_cast(fp32_infos[i].idx); + } + for (size_t i = 0; i < fp16_infos.size(); ++i) { + param_order[i + fp32_infos.size()] = static_cast(fp16_infos[i].idx); + } + VLOG(10) << "Fill ParamGradInfo ends"; // Step 2: determine the numel_with_padding and numel_offset @@ -568,45 +596,29 @@ class DistributedFusedLambInitOpKernel VLOG(10) << "Found the sharding arguments"; auto *param_info_t = ctx.Output("ParamInfo"); - param_info_t->Resize({6}); + param_info_t->Resize({8}); auto *param_info = param_info_t->mutable_data(platform::CPUPlace()); param_info[0] = static_cast(fp32_start_idx); param_info[1] = static_cast(fp32_local_param_num); param_info[2] = static_cast(fp32_infos.size()); - param_info[3] = static_cast(fp16_start_idx + fp32_infos.size()); - param_info[4] = static_cast(fp16_local_param_num); - param_info[5] = static_cast(fp16_infos.size()); + param_info[3] = ClipByBound(fp32_wd_end_idx, fp32_start_idx, + fp32_start_idx + fp32_local_param_num) - + static_cast(fp32_start_idx); + param_info[4] = static_cast(fp16_start_idx + fp32_infos.size()); + param_info[5] = static_cast(fp16_local_param_num); + param_info[6] = static_cast(fp16_infos.size()); + param_info[7] = ClipByBound(fp16_wd_end_idx, fp16_start_idx, + fp16_start_idx + fp16_local_param_num) - + static_cast(fp16_start_idx); VLOG(10) << "Start FP32 idx: " << param_info[0]; VLOG(10) << "Local FP32 param num: " << param_info[1]; VLOG(10) << "Global FP32 param num: " << param_info[2]; - VLOG(10) << "Start FP16 idx: " << param_info[3]; - VLOG(10) << "Local FP16 param num: " << param_info[4]; - VLOG(10) << "Global FP16 param num: " << param_info[5]; + VLOG(10) << "Start FP16 idx: " << param_info[4]; + VLOG(10) << "Local FP16 param num: " << param_info[5]; + VLOG(10) << "Global FP16 param num: " << param_info[6]; - // For WeightDecay, shard and perform H2D copy - const auto &origin_weight_decay = - ctx.Attr>("weight_decay"); - PADDLE_ENFORCE_EQ(params.size(), origin_weight_decay.size(), - platform::errors::InvalidArgument( - "The attr(weight_decay) should have the " - "same length with Input(Param).")); - std::vector shard_weight_decay; - shard_weight_decay.reserve(total_local_param_num); - for (size_t i = 0; i < fp32_local_param_num; ++i) { - shard_weight_decay.push_back( - origin_weight_decay[fp32_infos[i + fp32_start_idx].idx]); - } - for (size_t i = 0; i < fp16_local_param_num; ++i) { - shard_weight_decay.push_back( - origin_weight_decay[fp16_infos[i + fp16_start_idx].idx]); - } - - // For FusedIndices, launch CUDA kernel to do binary search - auto *fused_indices_t = ctx.Output("FusedIndices"); - fused_indices_t->Resize({static_cast(total_numel)}); - auto *fused_indices = fused_indices_t->mutable_data(place); std::vector numel_offsets; numel_offsets.reserve(params.size() + 1); for (const auto &info : fp32_infos) { @@ -621,21 +633,6 @@ class DistributedFusedLambInitOpKernel "The numel_offsets number must be one larger than " "the parameter number.")); VLOG(10) << "Total numel offset: " << FlattenToString(numel_offsets); - auto *fused_param_offset_t = - ctx.Output("FusedParamOffsets"); - fused_param_offset_t->Resize({static_cast(numel_offsets.size())}); - auto *fused_param_offset = fused_param_offset_t->mutable_data(place); - memory::Copy(place, fused_param_offset, platform::CPUPlace(), - numel_offsets.data(), - numel_offsets.size() * sizeof(numel_offsets[0]), stream); - auto config = platform::GetGpuLaunchConfig1D(dev_ctx, total_numel); - LambFillFusedIndicesCUDAKernel<<>>( - fused_param_offset, fused_indices, numel_offsets.size() - 1, - total_numel); - - std::vector lengths; - lengths.reserve(fp32_local_param_num + fp16_local_param_num); std::vector fp32_partial_numel_offsets; fp32_partial_numel_offsets.reserve(fp32_local_param_num + 1); @@ -659,9 +656,9 @@ class DistributedFusedLambInitOpKernel VLOG(10) << "FP32 Partial numel = [" << valid_start_n + fp32_infos[i].numel << "," << end_n + fp32_infos[i].numel; - lengths.push_back(end_n - valid_start_n); + auto len = end_n - valid_start_n; fp32_partial_numel_offsets.push_back(fp32_partial_numel_offsets.back() + - lengths.back()); + len); } std::vector fp16_partial_numel_offsets; @@ -682,9 +679,9 @@ class DistributedFusedLambInitOpKernel PADDLE_ENFORCE_NE(valid_start_n, end_n, platform::errors::InvalidArgument( "Indices sharding error. This may be a bug.")); - lengths.push_back(end_n - valid_start_n); + auto len = end_n - valid_start_n; fp16_partial_numel_offsets.push_back(fp16_partial_numel_offsets.back() + - lengths.back()); + len); } CopyVectorToCPUTensor(numel_offsets, @@ -696,23 +693,6 @@ class DistributedFusedLambInitOpKernel fp16_partial_numel_offsets, ctx.Output("FP16ShardFusedParamOffsets")); - // Fill the weight decay tensor - PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(), - platform::errors::InvalidArgument( - "Invalid weight decay sharding. This may be a bug.")); - std::vector wd_cpu; - for (size_t i = 0; i < shard_weight_decay.size(); ++i) { - int len = lengths[i]; - for (int j = 0; j < len; ++j) { - wd_cpu.push_back(shard_weight_decay[i]); - } - } - PADDLE_ENFORCE_EQ(wd_cpu.size() * nranks, fp32_numel + fp16_numel, - platform::errors::InvalidArgument( - "Invalid weight decay sharding. This may be a bug.")); - CopyVectorToTensor(wd_cpu, ctx.Output("WeightDecay"), - place, stream); - auto *global_scale = ctx.Output("GlobalScale"); if (!global_scale->IsInitialized()) { TensorFillConstant(dev_ctx, global_scale, {1}, 1.0f); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc index e5b27446eb3..8f7c87912e9 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc @@ -66,28 +66,31 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker { "The fp32 beta1 power accumulator tensor. Its shape is [1]."); AddInput("Beta2Pow", "The fp32 beta2 power accumulator tensor. Its shape is [1]."); - AddInput("FusedIndices", - "The param index of each element in FP32FusedParam. Its shape is " - "[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...]."); AddInput( "FusedParamOffsets", "The numel offset of each parameter inside the FP32FusedParam. Its " "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 " - "+ n_2, ...]."); - AddInput("FP32ShardFusedParamOffsets", - "The sharded numel offset of each parameter in the local rank. " - "Its shape is [fp32_local_param_num + 1]."); - AddInput("FP16ShardFusedParamOffsets", - "The sharded numel offset of each parameter in the local rank. " - "Its shape is [fp16_local_param_num + 1]."); - AddInput("WeightDecay", - "The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N]."); + "+ n_2, ...]. It should be in CPUPlace."); + AddInput( + "FP32ShardFusedParamOffsets", + "The sharded numel offset of each parameter in the local rank. " + "Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace."); + AddInput( + "FP16ShardFusedParamOffsets", + "The sharded numel offset of each parameter in the local rank. " + "Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace."); AddInput("ParamInfo", "The param info. It should be in CPUPlace, and its shape is [6]" - "CPUPlace, and its shape is [6]. It is " + "CPUPlace, and its shape is [8]. It is " "[fp32_shard_param_start_idx, fp32_local_param_num, " - "fp32_global_param_num, fp16_shard_param_start_idx, " - "fp16_local_param_num, fp16_global_param_num]."); + "fp32_global_param_num, fp32_weight_decay_end_idx, " + "fp16_shard_param_start_idx, " + "fp16_local_param_num, fp16_global_param_num, " + "fp16_weight_decay_end_idx]."); + AddInput("ParamOrder", + "The reordered parameter order. Inside this op, " + "the parameter would be reordered by data type and weight decay " + "value."); AddInput("LearningRate", "The fp32 learning rate tensor. Its shape is [1]."); @@ -116,6 +119,7 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker { "max_global_grad_norm", "The maximum global gradient l2-norm value for clipping. If " "max_global_grad_norm <= 0, no clipping would be performed."); + AddAttr("weight_decay", "The weight decay value."); AddAttr("clip_after_allreduce", "Whether to clip before allreduce, only valid when the " "world size is larger than 1."); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index 3f90140f772..ca0828a6f6a 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -87,7 +87,7 @@ struct L2NormFunctor { } }; -template +template static __global__ void MultiTensorL2NormReduceAgainCUDAKernel( const InT *x, OutT *y, int max_chunk_num) { int tensor_id = blockIdx.x; @@ -100,11 +100,7 @@ static __global__ void MultiTensorL2NormReduceAgainCUDAKernel( } sum = BlockReduce(storage).Reduce(sum, cub::Sum()); if (threadIdx.x == 0) { - if (NeedSqrt) { - y[blockIdx.x] = static_cast(sqrtf(sum)); - } else { - y[blockIdx.x] = static_cast(sum); - } + y[blockIdx.x] = static_cast(sum); } } @@ -118,6 +114,7 @@ static int GetChunkedVecSize(const T *ptr, int chunk_size) { constexpr int vec8 = alignof(platform::AlignedVector); constexpr int vec4 = alignof(platform::AlignedVector); constexpr int vec2 = alignof(platform::AlignedVector); + chunk_size *= sizeof(T); if (address % vec8 == 0 && chunk_size % vec8 == 0) { return std::min(8, valid_vec_size); } else if (address % vec4 == 0 && chunk_size % vec4 == 0) { @@ -129,27 +126,26 @@ static int GetChunkedVecSize(const T *ptr, int chunk_size) { } } -#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \ - case __vec_size: { \ - constexpr int kVecSize = __vec_size; \ - __VA_ARGS__; \ - break; \ +#define PD_VEC_LAUNCH_KERNEL_CASE(__vec_size, ...) \ + case __vec_size: { \ + constexpr int kVecSize = __vec_size; \ + __VA_ARGS__; \ + break; \ } -#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...) \ - do { \ - switch (__vec_size) { \ - PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \ - PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \ - PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \ - PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \ - } \ +#define PD_VEC_LAUNCH_KERNEL(__vec_size, ...) \ + do { \ + switch (__vec_size) { \ + PD_VEC_LAUNCH_KERNEL_CASE(8, __VA_ARGS__); \ + PD_VEC_LAUNCH_KERNEL_CASE(4, __VA_ARGS__); \ + PD_VEC_LAUNCH_KERNEL_CASE(2, __VA_ARGS__); \ + PD_VEC_LAUNCH_KERNEL_CASE(1, __VA_ARGS__); \ + } \ } while (0) // TODO(zengjinle): which chunk_size is better? -template +template static void MultiTensorL2Norm(const platform::CUDAPlace &place, gpuStream_t stream, const InT *x, const int *offsets, int n, OutT *y, @@ -158,7 +154,7 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place, constexpr int kNumTensor = MaxTensorNumPerLaunch; constexpr int kNumChunk = MaxChunkNumPerLaunch; - constexpr int kBlockDim = BlockDim; + constexpr int kBlockDim = 512; int max_chunk_num = -1; int vec_size = 8; @@ -181,22 +177,22 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place, auto *tmp_out_ptr = tmp_out.Alloc(n * max_chunk_num); FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream); -#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL \ - do { \ - using FunctorT = L2NormFunctor; \ - VLOG(10) << __func__ << " " << typeid(InT).name() \ - << " VecSize = " << kVecSize; \ - MultiTensorApply( \ - FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr, \ - max_chunk_num); \ +#define PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL \ + do { \ + using FunctorT = L2NormFunctor; \ + VLOG(10) << __func__ << " " << typeid(InT).name() \ + << " VecSize = " << kVecSize; \ + MultiTensorApply( \ + FunctorT(), stream, offsets, n, chunk_size, kBlockDim, x, tmp_out_ptr, \ + max_chunk_num); \ } while (0) - PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL); -#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL + PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL); +#undef PD_LAUNCH_MULTI_TENSOR_APPLY_L2_NORM_KERNEL - MultiTensorL2NormReduceAgainCUDAKernel<<>>( - tmp_out_ptr, y, max_chunk_num); + MultiTensorL2NormReduceAgainCUDAKernel< + MT, OutT, kBlockDim><<>>(tmp_out_ptr, y, + max_chunk_num); } template @@ -208,34 +204,17 @@ static void LogParamAndTrustRatioDivSquareNorm( auto tensors = ctx.MultiInput("Param"); if (tensors.empty()) return; + const auto *order = ctx.Input("ParamOrder")->data(); + size_t n = tensors.size(); auto place = tensors[0]->place(); auto pn_vec = ToVector(param_square_norm, n, place); auto tn_vec = ToVector(trust_ratio_div_square_norm, n, place); - std::vector fp32_indices, fp16_indices; - fp32_indices.reserve(n); - fp16_indices.reserve(n); - for (size_t i = 0; i < n; ++i) { - const auto *t = tensors[i]; - if (t->dtype() == phi::DataType::FLOAT32) { - fp32_indices.push_back(i); - } else if (t->dtype() == phi::DataType::FLOAT16) { - fp16_indices.push_back(i); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported data type %s.", t->dtype())); - } - } - - for (auto idx : fp16_indices) { - fp32_indices.push_back(idx); - } - const auto &names = ctx.GetOp().Inputs("Param"); - for (size_t i = 0; i < fp32_indices.size(); ++i) { - auto idx = fp32_indices[i]; + for (size_t i = 0; i < n; ++i) { + auto idx = order[i]; VLOG(LogLevel) << "Param " << tensors[idx]->dtype() << " " << names[idx] << " pn = " << pn_vec[i] << " , tn = " << tn_vec[i]; } @@ -353,7 +332,7 @@ static __global__ void CalcGradNormClipBeforeAllReduceScale( const T1 *__restrict__ global_scale, T1 max_global_grad_norm, const T1 *__restrict__ square_grad_norm, T1 *__restrict__ out1, T2 *__restrict__ out2, T1 clip_rescale_grad) { - T1 grad_norm = static_cast(sqrt(*square_grad_norm)) * clip_rescale_grad; + T1 grad_norm = static_cast(sqrtf(*square_grad_norm)) * clip_rescale_grad; T1 scale = global_scale[0] * max_global_grad_norm / (1e-6 + grad_norm); bool found_nan_inf = !isfinite(scale); if (scale >= 1 || found_nan_inf) { @@ -380,19 +359,24 @@ static __global__ void SetNanInfValueCUDAKernelTwoFlag(const bool *in_flag_p_1, ((*in_flag_p_1) || (*in_flag_p_2)) ? __int_as_float(0x7fffffffU) : 0.0f; } -// TODO(zengjinle): Vectorize this function -// NOTE: this method does not update Beta1Pow and Beta2Pow! -template -static __global__ void UpdateLambMoment( +template +static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel( const T *__restrict__ param_p, const GradT *__restrict__ grad_p, const T *__restrict__ square_grad_norm_p, - const T *__restrict__ global_scale, const IndexT *__restrict__ indices, - const T *__restrict__ weight_decay_p, const T *__restrict__ beta1pow_p, + const T *__restrict__ global_scale, const T *__restrict__ beta1pow_p, const T *__restrict__ beta2pow_p, T *__restrict__ mom1_p, - T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p, T beta1, T beta2, - T epsilon, T max_global_grad_norm, int num, T rescale_grad) { + T *__restrict__ mom2_p, T *__restrict__ trust_ratio_div_p, bool *found_inf, + T weight_decay, int weight_decay_end_numel, T beta1, T beta2, T epsilon, + T max_global_grad_norm, int num, T rescale_grad) { T square_grad_norm = *square_grad_norm_p; - if (!isfinite(square_grad_norm)) return; + bool need_update_found_inf = + (found_inf && threadIdx.x == 0 && blockIdx.x == 0); + if (!isfinite(square_grad_norm)) { + if (need_update_found_inf) *found_inf = true; + return; + } else if (need_update_found_inf) { + *found_inf = false; + } T scale = rescale_grad / global_scale[0]; if (max_global_grad_norm > 0) { @@ -406,27 +390,112 @@ static __global__ void UpdateLambMoment( T one_minus_beta1pow = 1 - beta1pow_p[0]; T one_minus_beta2pow = 1 - beta2pow_p[0]; - CUDA_KERNEL_LOOP(i, num) { - T p = param_p[i]; - T g = static_cast(grad_p[i]) * scale; - T weight_decay = weight_decay_p[i]; - T mom1 = mom1_p[i]; - T mom2 = mom2_p[i]; - - mom1 = beta1 * mom1 + (1 - beta1) * g; - mom2 = beta2 * mom2 + (1 - beta2) * g * g; - - T mom1_unbiased = mom1 / one_minus_beta1pow; - T mom2_unbiased = mom2 / one_minus_beta2pow; - T trust_ratio_div = - mom1_unbiased / (sqrtf(mom2_unbiased) + epsilon) + weight_decay * p; - - mom1_p[i] = mom1; - mom2_p[i] = mom2; - trust_ratio_div_p[i] = trust_ratio_div; + int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize; + int stride = blockDim.x * gridDim.x * VecSize; + + for (; i + VecSize <= num; i += stride) { + platform::AlignedVector param_vec; + platform::AlignedVector grad_vec; + platform::AlignedVector weight_decay_vec; + platform::AlignedVector mom1_vec; + platform::AlignedVector mom2_vec; + platform::AlignedVector trust_ratio_div_vec; + + T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay; + if (cur_weight_decay != static_cast(0.0)) { + platform::Load(param_p + i, ¶m_vec); + } else { +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + param_vec[j] = static_cast(0); + } + } + platform::Load(grad_p + i, &grad_vec); + platform::Load(mom1_p + i, &mom1_vec); + platform::Load(mom2_p + i, &mom2_vec); + +#define PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(__param, __grad, __mom1, __mom2, \ + __trust_ratio_div, __idx) \ + T p = __param[__idx]; \ + T g = static_cast(__grad[__idx]) * scale; \ + T mom1 = __mom1[__idx]; \ + T mom2 = __mom2[__idx]; \ + mom1 = beta1 * mom1 + (1 - beta1) * g; \ + mom2 = beta2 * mom2 + (1 - beta2) * g * g; \ + T mom1_unbiased = mom1 / one_minus_beta1pow; \ + T mom2_unbiased = mom2 / one_minus_beta2pow; \ + __trust_ratio_div[__idx] = \ + mom1_unbiased / (sqrtf(mom2_unbiased) + epsilon) + cur_weight_decay * p; \ + __mom1[__idx] = mom1; \ + __mom2[__idx] = mom2; + +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(param_vec, grad_vec, mom1_vec, + mom2_vec, trust_ratio_div_vec, j); + } + + platform::Store(mom1_vec, mom1_p + i); + platform::Store(mom2_vec, mom2_p + i); + platform::Store(trust_ratio_div_vec, trust_ratio_div_p + i); + } + + for (; i < num; ++i) { + T cur_weight_decay = (i < weight_decay_end_numel) * weight_decay; + PD_LAMB_MOM_TRUST_RATIO_DIV_UPDATE(param_p, grad_p, mom1_p, mom2_p, + trust_ratio_div_p, i); } } +template +static void MultiTensorUpdateLambMomentAndTrustRatioDiv( + const platform::CUDADeviceContext &dev_ctx, const int *offsets, int n, + const T *param_p, const GradT *grad_p, const T *square_grad_norm_p, + const T *global_scale, const T *beta1pow_p, const T *beta2pow_p, T *mom1_p, + T *mom2_p, T *trust_ratio_div_p, bool *found_inf_p, T weight_decay, + int weight_decay_end_idx, T beta1, T beta2, T epsilon, + T max_global_grad_norm, T rescale_grad) { + if (n <= 0) return; + int numel = offsets[n] - offsets[0]; + PADDLE_ENFORCE_GE(weight_decay_end_idx, 0, + platform::errors::InvalidArgument( + "The weight decay end index should be >= 0.")); + PADDLE_ENFORCE_LE(weight_decay_end_idx, n, + platform::errors::InvalidArgument( + "The weight decay end index should be < %d.", n)); + auto weight_decay_end_numel = offsets[weight_decay_end_idx] - offsets[0]; + + int vec_size = GetChunkedVecSize(param_p, 0); + vec_size = std::min(vec_size, GetChunkedVecSize(grad_p, 0)); + vec_size = std::min(vec_size, GetChunkedVecSize(mom1_p, 0)); + vec_size = std::min(vec_size, GetChunkedVecSize(mom2_p, 0)); + vec_size = std::min(vec_size, GetChunkedVecSize(trust_ratio_div_p, 0)); + for (int i = 0; i < n; ++i) { + auto length = offsets[i + 1] - offsets[i]; + while (length % vec_size != 0) { + vec_size /= 2; + } + } + + VLOG(1) << __func__ << " VecSize = " << vec_size; + + auto stream = dev_ctx.stream(); + auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size); + +#define PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL \ + do { \ + UpdateLambMomentAndTrustRatioDivCUDAKernel<<< \ + config.block_per_grid, config.thread_per_block, 0, stream>>>( \ + param_p, grad_p, square_grad_norm_p, global_scale, beta1pow_p, \ + beta2pow_p, mom1_p, mom2_p, trust_ratio_div_p, found_inf_p, \ + weight_decay, weight_decay_end_numel, beta1, beta2, epsilon, \ + max_global_grad_norm, numel, rescale_grad); \ + } while (0) + + PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL); +#undef PD_LAUNCH_LAMB_MOM_TRUST_RATIO_DIV_KERNEL +} + template struct LambBetaPowUpdateOnceHelper { LambBetaPowUpdateOnceHelper(T *beta1pow, T *beta2pow, T beta1, T beta2) { @@ -468,33 +537,6 @@ struct LambBetaPowUpdateOnceHelper { HOSTDEVICE void UpdateBetaPows() const {} }; -template -struct LambFoundInfHelper { - public: - explicit LambFoundInfHelper(bool *found_inf) : found_inf_(found_inf) { - PADDLE_ENFORCE_NOT_NULL(found_inf, - platform::errors::InvalidArgument( - "The found_inf should not be nullptr.")); - } - - HOSTDEVICE void UpdateFoundInf(bool value) { *found_inf_ = value; } - - private: - bool *__restrict__ found_inf_; -}; - -template <> -struct LambFoundInfHelper { - public: - explicit LambFoundInfHelper(bool *found_inf) { - PADDLE_ENFORCE_EQ( - found_inf, nullptr, - platform::errors::InvalidArgument("The found_inf should be nullptr.")); - } - - HOSTDEVICE void UpdateFoundInf(bool) {} -}; - template struct LambParamHelper { LambParamHelper(T *param, MasterT *master_param) { @@ -509,12 +551,9 @@ struct LambParamHelper { master_param_ = master_param; } - HOSTDEVICE void SetParam(int i, MasterT updated_p) { - param_[i] = static_cast(updated_p); - master_param_[i] = updated_p; - } + HOSTDEVICE T *__restrict__ ParamPtr() { return param_; } - HOSTDEVICE MasterT GetParam(int i) { return master_param_[i]; } + HOSTDEVICE MasterT *__restrict__ MasterParamPtr() { return master_param_; } private: T *__restrict__ param_; @@ -538,158 +577,169 @@ struct LambParamHelper { param_ = param; } - HOSTDEVICE void SetParam(int i, MasterT updated_p) { - param_[i] = static_cast(updated_p); - } + HOSTDEVICE T *__restrict__ ParamPtr() { return param_; } - HOSTDEVICE MasterT GetParam(int i) { - return static_cast>(param_[i]); - } + HOSTDEVICE constexpr MasterT *MasterParamPtr() { return nullptr; } private: T *__restrict__ param_; }; -template -struct LambParamAndBetaPowsUpdateHelper - : public LambParamHelper, - public LambBetaPowUpdateOnceHelper, NeedUpdateBetaPow>, - public LambFoundInfHelper { - LambParamAndBetaPowsUpdateHelper( - ParamT *param, MasterT *master_param, MasterT *beta1pow, - MasterT *beta2pow, MasterT beta1, MasterT beta2, - bool *found_inf, const MasterT *trust_ratio_div, - const MasterT *lr, const IndexT *index, +template +struct LambUpdateParamAndBetaPowsFunctor { + DEVICE void operator()( + int tensor_id, int chunk_id, int offset, int size, + LambParamHelper param_helper, + const MasterT *trust_ratio_div, const MasterT *lr, const MasterT *param_square_norm, - const MasterT *trust_ratio_div_square_norm, - const MasterT *update_flag) - : LambParamHelper(param, master_param), - LambBetaPowUpdateOnceHelper, NeedUpdateBetaPow>( - beta1pow, beta2pow, beta1, beta2), - LambFoundInfHelper(found_inf), - trust_ratio_div(trust_ratio_div), - lr(lr), - index(index), - param_square_norm(param_square_norm), - trust_ratio_div_square_norm(trust_ratio_div_square_norm), - update_flag(update_flag) {} - - const MasterT *__restrict__ trust_ratio_div; - const MasterT *__restrict__ lr; - const IndexT *__restrict__ index; - const MasterT *__restrict__ param_square_norm; - const MasterT *__restrict__ trust_ratio_div_square_norm; - const MasterT *__restrict__ update_flag; -}; + const MasterT *trust_ratio_div_square_norm, const bool *found_inf, + LambBetaPowUpdateOnceHelper, NeedUpdateBetaPow> + betapow_helper) const { + if (*found_inf) return; + + using MT = MasterT; -template -static __global__ void LambUpdateParamAndBetaPowsCUDAKernel( - LambParamAndBetaPowsUpdateHelper - args, - int num) { - auto should_update = *args.update_flag; - if (!isfinite(should_update)) { - if (HasFoundInf && threadIdx.x == 0 && blockIdx.x == 0) { - args.UpdateFoundInf(true); + MT p_square_norm = param_square_norm[tensor_id]; + MT t_square_norm = trust_ratio_div_square_norm[tensor_id]; + MT lr_value = *lr; + MT ratio = (p_square_norm != static_cast(0) && + t_square_norm != static_cast(0) + ? lr_value * sqrtf(p_square_norm / t_square_norm) + : lr_value); + + int i; + int stride = blockDim.x * VecSize; + + ParamT *param = param_helper.ParamPtr() + offset; + MT *master_param = HasMasterParam ? param_helper.MasterParamPtr() + offset + : param_helper.MasterParamPtr(); + trust_ratio_div += offset; + + for (i = threadIdx.x * VecSize; i + VecSize <= size; i += stride) { + platform::AlignedVector trust_ratio_div_vec; + platform::Load(trust_ratio_div + i, &trust_ratio_div_vec); + if (HasMasterParam) { + platform::AlignedVector master_param_vec; + platform::Load(master_param + i, &master_param_vec); + platform::AlignedVector param_vec; +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + MT p = master_param_vec[j] - ratio * trust_ratio_div_vec[j]; + master_param_vec[j] = p; + param_vec[j] = static_cast(p); + } + platform::Store(master_param_vec, master_param + i); + platform::Store(param_vec, param + i); + } else { + platform::AlignedVector param_vec; + platform::Load(param + i, ¶m_vec); +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + MT p = static_cast(param_vec[j]) - ratio * trust_ratio_div_vec[j]; + param_vec[j] = static_cast(p); + } + platform::Store(param_vec, param + i); + } + } + + for (; i < size; ++i) { + if (HasMasterParam) { + MT p = master_param[i] - ratio * trust_ratio_div[i]; + master_param[i] = p; + param[i] = static_cast(p); + } else { + MT p = static_cast(param[i]) - ratio * trust_ratio_div[i]; + param[i] = static_cast(p); + } + } + + if (NeedUpdateBetaPow && threadIdx.x == 0 && blockIdx.x == 0) { + betapow_helper.UpdateBetaPows(); } - return; - } else if (HasFoundInf && threadIdx.x == 0 && blockIdx.x == 0) { - args.UpdateFoundInf(false); } +}; - if (NeedUpdateBetaPow && threadIdx.x == 0 && blockIdx.x == 0) { - args.UpdateBetaPows(); +// TODO(zengjinle): which block_dim and chunk_size would be better? +template +static void MultiTensorUpdateLambParamAndBetaPows( + const platform::CUDADeviceContext &dev_ctx, const int *offsets, int n, + const MasterT *trust_ratio_div, const MasterT *lr, + const MasterT *param_square_norm, + const MasterT *trust_ratio_div_square_norm, const bool *found_inf, + ParamT *param, MasterT *master_param, MasterT *beta1pow, + MasterT *beta2pow, MasterT beta1, MasterT beta2, + int chunk_size = 65536) { + constexpr bool kHasMasterParam = + !(std::is_same>::value); + + bool has_beta_pow = (beta1pow != nullptr); + if (has_beta_pow) { + PADDLE_ENFORCE_NOT_NULL(beta2pow, platform::errors::InvalidArgument( + "Beta2Pow should not be nullptr.")); + } else { + PADDLE_ENFORCE_EQ(beta2pow, nullptr, platform::errors::InvalidArgument( + "Beta2Pow should be nullptr.")); } - using MT = MasterT; + const int block_dim = 512; - MT lr_value = *args.lr; - CUDA_KERNEL_LOOP(i, num) { - MT p = args.GetParam(i); - MT t = args.trust_ratio_div[i]; - auto norm_idx = args.index[i]; - MT p_square_norm = args.param_square_norm[norm_idx]; - MT t_square_norm = args.trust_ratio_div_square_norm[norm_idx]; + int vec_size = 8; + for (int i = 0; i < n; ++i) { + int offset = offsets[i] - offsets[0]; + vec_size = + std::min(vec_size, GetChunkedVecSize(param + offset, chunk_size)); + if (kHasMasterParam) { + vec_size = std::min(vec_size, + GetChunkedVecSize(master_param + offset, chunk_size)); + } + vec_size = std::min( + vec_size, GetChunkedVecSize(trust_ratio_div + offset, chunk_size)); + } - MT p_norm = static_cast(sqrtf(p_square_norm)); - MT t_norm = static_cast(sqrtf(t_square_norm)); + VLOG(1) << __func__ << " VecSize = " << vec_size; - auto update = (p_norm != static_cast(0) && t_norm != static_cast(0)) - ? p_norm / t_norm - : static_cast(1); + constexpr auto kNumTensor = MaxTensorNumPerLaunch; + constexpr auto kNumChunk = MaxChunkNumPerLaunch; - MT updated_p = p - lr_value * update * t; - args.SetParam(i, updated_p); - } -} + auto stream = dev_ctx.stream(); +#define PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(__has_beta_pow) \ + do { \ + using FunctorT = \ + LambUpdateParamAndBetaPowsFunctor; \ + LambParamHelper param_helper(param, \ + master_param); \ + LambBetaPowUpdateOnceHelper, __has_beta_pow> \ + betapow_helper(beta1pow, beta2pow, beta1, beta2); \ + launcher.Launch(FunctorT(), param_helper, trust_ratio_div, lr, \ + param_square_norm, trust_ratio_div_square_norm, found_inf, \ + betapow_helper); \ + } while (0) -template -static void LambUpdateParamAndBetaPows( - const platform::CUDADeviceContext &dev_ctx, - const MasterT *trust_ratio_div, const MasterT *lr, - const IndexT *index, const MasterT *param_square_norm, - const MasterT *trust_ratio_div_square_norm, - const MasterT *update_flag, MasterT **beta1pow, - MasterT **beta2pow, bool **found_inf, MasterT beta1, - MasterT beta2, int num, ParamT *param, - MasterT *master_param, gpuStream_t stream) { - if (num == 0) return; - - bool has_master_param = !(std::is_same>::value); - auto has_beta_pow = (*beta1pow) != nullptr && (*beta2pow) != nullptr; - auto has_found_inf = (*found_inf) != nullptr; - -#define PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL( \ - __has_master_param, __has_beta_pow, __has_found_inf) \ - do { \ - LambParamAndBetaPowsUpdateHelper \ - helper(param, master_param, *beta1pow, *beta2pow, beta1, beta2, \ - *found_inf, trust_ratio_div, lr, index, param_square_norm, \ - trust_ratio_div_square_norm, update_flag); \ - auto config = platform::GetGpuLaunchConfig1D(dev_ctx, num); \ - LambUpdateParamAndBetaPowsCUDAKernel<<< \ - config.block_per_grid, config.thread_per_block, 0, stream>>>(helper, \ - num); \ +#define PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE \ + do { \ + auto callback = [&]( \ + const MultiTensorLauncher &launcher, \ + int launch_n) { \ + if (has_beta_pow && launch_n == 0) { \ + PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(true); \ + beta1pow = nullptr; \ + beta2pow = nullptr; \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW(false); \ + } \ + }; \ + MultiTensorApplyWithCallback( \ + stream, offsets, n, chunk_size, block_dim, callback); \ } while (0) - if (has_master_param) { - if (has_beta_pow) { - if (has_found_inf) { - PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, true, true); - } else { - PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, true, false); - } - } else { - if (has_found_inf) { - PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, false, true); - } else { - PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(true, false, false); - } - } - } else { - if (has_beta_pow) { - if (has_found_inf) { - PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, true, true); - } else { - PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, true, false); - } - } else { - if (has_found_inf) { - PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, false, true); - } else { - PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL(false, false, false); - } - } - } + PD_VEC_LAUNCH_KERNEL(vec_size, + PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE); - *beta1pow = nullptr; - *beta2pow = nullptr; - *found_inf = nullptr; -#undef PADDLE_LAUNCH_LAMB_UPDATE_PARAM_KERNEL +#undef PD_LAUNCH_MULTI_TENSOR_UPDATE_PARAM_BETAPOW +#undef PD_LAUNCH_VEC_MULTI_TENSOR_UPDATE_PARAM_BETAPOW_CASE } #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) @@ -1005,15 +1055,16 @@ class DistributedFusedLambOpKernel "Too many parameter number. Only <= %d is supported.", std::numeric_limits::max())); - // Step 3: Get FusedIndices, ParamInfo - const auto *indices = GetInputTensorPtr(ctx, "FusedIndices"); + // Step 3: Get ParamInfo const auto *param_info_tensor = GetInputTensorPtr(ctx, "ParamInfo"); auto fp32_local_start_idx = param_info_tensor[0]; auto fp32_local_param_num = param_info_tensor[1]; auto fp32_global_param_num = param_info_tensor[2]; - auto fp16_local_start_idx = param_info_tensor[3]; - auto fp16_local_param_num = param_info_tensor[4]; - auto fp16_global_param_num = param_info_tensor[5]; + auto fp32_weight_decay_end_idx = param_info_tensor[3]; + auto fp16_local_start_idx = param_info_tensor[4]; + auto fp16_local_param_num = param_info_tensor[5]; + auto fp16_global_param_num = param_info_tensor[6]; + auto fp16_weight_decay_end_idx = param_info_tensor[7]; auto local_param_num = fp32_local_param_num + fp16_local_param_num; auto param_num = fp32_global_param_num + fp16_global_param_num; @@ -1031,7 +1082,7 @@ class DistributedFusedLambOpKernel << " , fp16_global_param_num = " << fp16_global_param_num; // Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow, - // WeightDecay, GlobalScale, FoundInf + // GlobalScale, FoundInf const auto *global_scale = GetInputTensorPtr(ctx, "GlobalScale"); const auto *lr = GetInputTensorPtr(ctx, "LearningRate"); int64_t partial_numel = 0; @@ -1065,14 +1116,15 @@ class DistributedFusedLambOpKernel GetSameInOutTensorPtr(ctx, place, "Beta1Pow", "Beta1PowOut"); auto *beta2pow = GetSameInOutTensorPtr(ctx, place, "Beta2Pow", "Beta2PowOut"); - const float *weight_decay = GetInputTensorPtr(ctx, "WeightDecay"); auto *found_inf_t = ctx.Output("FoundInf"); found_inf_t->Resize({1}); auto *found_inf = found_inf_t->mutable_data(place); - // Step 5: Get attributes beta1, beta2, epsilon, max_grad_norm, ring_id, + // Step 5: Get attributes weight_decay, beta1, beta2, epsilon, + // max_grad_norm, ring_id, // use_master_param_norm, is_grad_scaled_by_nranks + auto weight_decay = ctx.Attr("weight_decay"); auto beta1 = ctx.Attr("beta1"); auto beta2 = ctx.Attr("beta2"); auto epsilon = ctx.Attr("epsilon"); @@ -1105,7 +1157,8 @@ class DistributedFusedLambOpKernel platform::float16 *fp16_sum_grad; auto fp32_numel_each_device = fp32_numel / num_devices; auto fp16_numel_each_device = fp16_numel / num_devices; - if (num_devices > 1) { + if (num_devices > 1 || + (max_global_grad_norm > 0 && !clip_after_allreduce)) { auto ptr = sum_grad_buffer.Alloc( fp32_numel_each_device * sizeof(float) + fp16_numel_each_device * sizeof(platform::float16)); @@ -1181,7 +1234,11 @@ class DistributedFusedLambOpKernel float, platform::float16><<<1, 1, 0, stream>>>( global_scale, max_global_grad_norm, fp32_square_grad_norm, fp32_scale, fp16_scale, clip_scale); - VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place); + if (fp32_scale) { + VLOG(1) << "Grad scale: " << FlattenToString(fp32_scale, 1, place); + } else { + VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place); + } if (num_devices > 1) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32, @@ -1218,36 +1275,56 @@ class DistributedFusedLambOpKernel VLOG(10) << "ReduceScatter done"; // Step 7: update the moment1, moment2. Calcuate the trust_ratio_div + auto *fused_offsets_t = ctx.Input("FusedParamOffsets"); + auto *fused_offsets = fused_offsets_t->data(); + auto *fp32_partial_fused_offsets_t = + ctx.Input("FP32ShardFusedParamOffsets"); + const auto *fp32_partial_fused_offsets = + fp32_partial_fused_offsets_t->data(); + auto *fp16_partial_fused_offsets_t = + ctx.Input("FP16ShardFusedParamOffsets"); + const auto *fp16_partial_fused_offsets = + fp16_partial_fused_offsets_t->data(); + + VLOG(1) << "FusedParamOffsets: " + << FlattenToString(fused_offsets, fused_offsets_t->numel(), + fused_offsets_t->place()); + VLOG(1) << "FP32ShardFusedParamOffsets: " + << FlattenToString(fp32_partial_fused_offsets, + fp32_partial_fused_offsets_t->numel(), + fp32_partial_fused_offsets_t->place()); + VLOG(1) << "FP16ShardFusedParamOffsets: " + << FlattenToString(fp16_partial_fused_offsets, + fp16_partial_fused_offsets_t->numel(), + fp16_partial_fused_offsets_t->place()); + memory::Buffer trust_ratio_div_buffer(place); auto *trust_ratio_div = trust_ratio_div_buffer.Alloc(partial_numel); auto fp32_offset = rank * fp32_numel_each_device; auto fp16_offset = rank * fp16_numel_each_device; if (has_fp32_param) { - auto config = - platform::GetGpuLaunchConfig1D(dev_ctx, fp32_numel_each_device); VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts"; - UpdateLambMoment<<>>( + MultiTensorUpdateLambMomentAndTrustRatioDiv( + dev_ctx, fp32_partial_fused_offsets, fp32_local_param_num, fp32_param + fp32_offset, fp32_sum_grad, fp32_square_grad_norm, - global_scale, indices + fp32_offset, weight_decay, beta1pow, beta2pow, - moment1, moment2, trust_ratio_div, beta1, beta2, epsilon, - max_global_grad_norm, fp32_numel_each_device, rescale_grad); + global_scale, beta1pow, beta2pow, moment1, moment2, trust_ratio_div, + found_inf, weight_decay, fp32_weight_decay_end_idx, beta1, beta2, + epsilon, max_global_grad_norm, rescale_grad); VLOG(10) << "Update FP32 Moment and TrustRatioDiv done"; } float *master_param = nullptr; if (has_fp16_param) { master_param = fp32_param + fp32_numel; - auto config = - platform::GetGpuLaunchConfig1D(dev_ctx, fp16_numel_each_device); VLOG(10) << "Update FP16 Moment and TrustRatioDiv starts"; - UpdateLambMoment<<>>( + auto tmp_found_inf = has_fp32_param ? nullptr : found_inf; + MultiTensorUpdateLambMomentAndTrustRatioDiv( + dev_ctx, fp16_partial_fused_offsets, fp16_local_param_num, master_param + fp16_offset, fp16_sum_grad, fp32_square_grad_norm, - global_scale, indices + fp32_numel + fp16_offset, weight_decay, - beta1pow, beta2pow, moment1 + fp32_numel_each_device, + global_scale, beta1pow, beta2pow, moment1 + fp32_numel_each_device, moment2 + fp32_numel_each_device, - trust_ratio_div + fp32_numel_each_device, beta1, beta2, epsilon, - max_global_grad_norm, fp16_numel_each_device, rescale_grad); + trust_ratio_div + fp32_numel_each_device, tmp_found_inf, weight_decay, + fp16_weight_decay_end_idx, beta1, beta2, epsilon, + max_global_grad_norm, rescale_grad); VLOG(10) << "Update FP16 Moment and TrustRatioDiv done"; } @@ -1257,30 +1334,6 @@ class DistributedFusedLambOpKernel memory::Buffer square_norm_buffer(place); auto *param_square_norm = square_norm_buffer.Alloc(2 * param_num); auto *trust_ratio_div_square_norm = param_square_norm + param_num; - - auto *fused_offsets_t = ctx.Input("FusedParamOffsets"); - auto *fused_offsets = fused_offsets_t->data(); - auto *fp32_partial_fused_offsets_t = - ctx.Input("FP32ShardFusedParamOffsets"); - const auto *fp32_partial_fused_offsets = - fp32_partial_fused_offsets_t->data(); - auto *fp16_partial_fused_offsets_t = - ctx.Input("FP16ShardFusedParamOffsets"); - const auto *fp16_partial_fused_offsets = - fp16_partial_fused_offsets_t->data(); - - VLOG(1) << "FusedParamOffsets: " - << FlattenToString(fused_offsets, fused_offsets_t->numel(), - fused_offsets_t->place()); - VLOG(1) << "FP32ShardFusedParamOffsets: " - << FlattenToString(fp32_partial_fused_offsets, - fp32_partial_fused_offsets_t->numel(), - fp32_partial_fused_offsets_t->place()); - VLOG(1) << "FP16ShardFusedParamOffsets: " - << FlattenToString(fp16_partial_fused_offsets, - fp16_partial_fused_offsets_t->numel(), - fp16_partial_fused_offsets_t->place()); - if (num_devices > 1) { if (use_master_param_norm) { FillZeroWithPtr(param_square_norm + fp32_global_param_num, @@ -1296,11 +1349,11 @@ class DistributedFusedLambOpKernel fp16_partial_fused_offsets, fp16_local_param_num, param_square_norm + fp16_local_start_idx); } else { - // NOTE: extra computation is performed. We can improve this performance - // if needed in the future. MultiTensorL2Norm( - place, stream, fp16_param, fused_offsets + fp32_global_param_num, - fp16_global_param_num, param_square_norm + fp32_global_param_num); + place, stream, fp16_param + fused_offsets[fp16_local_start_idx] - + fused_offsets[fp32_global_param_num], + fused_offsets + fp16_local_start_idx, fp16_local_param_num, + param_square_norm + fp16_local_start_idx); } MultiTensorL2Norm(place, stream, trust_ratio_div, @@ -1333,26 +1386,29 @@ class DistributedFusedLambOpKernel // Step 9: update parameter, beta1pow, beta2pow. All gather parameters. if (has_fp32_param) { - LambUpdateParamAndBetaPows( - dev_ctx, trust_ratio_div, lr, indices + fp32_offset, - param_square_norm, trust_ratio_div_square_norm, fp32_square_grad_norm, - &beta1pow, &beta2pow, &found_inf, beta1, beta2, - fp32_numel_each_device, fp32_param + fp32_offset, nullptr, stream); + MultiTensorUpdateLambParamAndBetaPows( + dev_ctx, fp32_partial_fused_offsets, fp32_local_param_num, + trust_ratio_div, lr, param_square_norm + fp32_local_start_idx, + trust_ratio_div_square_norm + fp32_local_start_idx, found_inf, + fp32_param + fp32_offset, nullptr, beta1pow, beta2pow, beta1, beta2); if (num_devices > 1) { // ncclAllGather PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( fp32_param + fp32_offset, fp32_param, fp32_numel_each_device, ncclFloat32, comm, stream)); } + + beta1pow = nullptr; + beta2pow = nullptr; } if (has_fp16_param) { - LambUpdateParamAndBetaPows( - dev_ctx, trust_ratio_div + fp32_numel_each_device, lr, - indices + fp32_numel + fp16_offset, param_square_norm, - trust_ratio_div_square_norm, fp32_square_grad_norm, &beta1pow, - &beta2pow, &found_inf, beta1, beta2, fp16_numel_each_device, - fp16_param + fp16_offset, master_param + fp16_offset, stream); - + MultiTensorUpdateLambParamAndBetaPows( + dev_ctx, fp16_partial_fused_offsets, fp16_local_param_num, + trust_ratio_div + fp32_numel_each_device, lr, + param_square_norm + fp16_local_start_idx, + trust_ratio_div_square_norm + fp16_local_start_idx, found_inf, + fp16_param + fp16_offset, master_param + fp16_offset, beta1pow, + beta2pow, beta1, beta2); if (num_devices > 1) { // ncclAllGather PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather( diff --git a/paddle/fluid/operators/optimizers/multi_tensor_apply.h b/paddle/fluid/operators/optimizers/multi_tensor_apply.h index 5d8d03c733d..179e8f45254 100644 --- a/paddle/fluid/operators/optimizers/multi_tensor_apply.h +++ b/paddle/fluid/operators/optimizers/multi_tensor_apply.h @@ -94,11 +94,40 @@ static __global__ void MultiTensorApplyCUDAKernel( args...); } -template -static void MultiTensorApply(Functor functor, gpuStream_t stream, - const int *offsets, int n, int chunk_size, - Args... args) { +template +class MultiTensorLauncher { + public: + MultiTensorLauncher( + const TensorMetaList &meta, + const int &chunk_id, const int &chunk_size, const int &block_dim, + const gpuStream_t &stream) + : meta_(meta), + chunk_id_(chunk_id), + chunk_size_(chunk_size), + block_dim_(block_dim), + stream_(stream) {} + + template + void Launch(Functor &&functor, Args &&... args) const { + MultiTensorApplyCUDAKernel< + Functor, MaxTensorNumPerLaunch, + MaxChunkNumPerLaunch><<>>( + functor, meta_, chunk_size_, args...); + } + + private: + const TensorMetaList &meta_; + const int &chunk_id_; + const int &chunk_size_; + const int &block_dim_; + const gpuStream_t &stream_; +}; + +template +static void MultiTensorApplyWithCallback(gpuStream_t stream, const int *offsets, + int n, int chunk_size, int block_dim, + Callback &&callback) { if (n == 0) return; constexpr auto NumTensor = MaxTensorNumPerLaunch; @@ -110,6 +139,11 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream, int numel_offset = 0; metas.start_tensor_id = 0; metas.start_chunk_id = 0; + int launch_num = 0; + + MultiTensorLauncher launcher( + metas, chunk_id, chunk_size, block_dim, stream); + for (int i = 0; i < n; ++i) { auto length = offsets[i + 1] - offsets[i]; if (tensor_id == 0) { @@ -132,9 +166,8 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream, bool last_chunk = (i + 1 == n && j + 1 == chunk_num); if (tensor_full || block_full || last_chunk) { - MultiTensorApplyCUDAKernel<<>>( - functor, metas, chunk_size, args...); + callback(launcher, launch_num); + ++launch_num; chunk_id = 0; if (j + 1 == chunk_num) { // chunk for the current tensor is full metas.start_chunk_id = 0; @@ -152,5 +185,17 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream, } } +template +static void MultiTensorApply(Functor functor, gpuStream_t stream, + const int *offsets, int n, int chunk_size, + int block_dim, Args &&... args) { + auto callback = [&](const MultiTensorLauncher &launcher, + int i) { launcher.Launch(functor, args...); }; + MultiTensorApplyWithCallback( + stream, offsets, n, chunk_size, block_dim, callback); +} + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py b/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py index e0529c5d5f8..00d2a1f71d6 100644 --- a/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py @@ -144,6 +144,11 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): grad_clip = kwargs.get('grad_clip', None) clip_after_allreduce = kwargs.get('clip_after_allreduce', True) + parameters = [p.name for p in main.all_parameters()] + exclude_fn = lambda var: var.name in parameters[::4] + kwargs['exclude_from_weight_decay_fn'] = exclude_fn + kwargs['lamb_weight_decay'] = 0.1 + if use_distributed_lamb: optimizer_class = DistributedFusedLamb kwargs = dict(kwargs) diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py index e7c3cfbb7b9..cc33a909632 100644 --- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py +++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py @@ -171,10 +171,7 @@ class DistributedFusedLamb(Optimizer): moment2.is_distributed = True beta1pow = self._create_persistable_var('beta1pow') beta2pow = self._create_persistable_var('beta2pow') - fused_indices = self._create_persistable_var( - 'fused_indices', dtype='int32') - weight_decay = self._create_persistable_var('weight_decay') - weight_decay.is_distributed = True + param_info = self._create_persistable_var('param_info', dtype='int32') param_info.is_distributed = True @@ -189,17 +186,20 @@ class DistributedFusedLamb(Optimizer): 'fp16_partial_fused_offsets', dtype='int32') fp16_partial_fused_offsets.is_distributed = True + param_order = self._create_persistable_var('param_order', dtype='int32') + param_order.is_distributed = True + rank = get_rank() nranks = get_world_size() scale = self._get_or_create_scale() params = [p for p, _ in params_grads] grads = [g for _, g in params_grads] - weight_decay_values = [self._weight_decay] * len(params) + apply_weight_decay = [1] * len(params) if self._exclude_from_weight_decay_fn is not None: for i, p in enumerate(params): if self._exclude_from_weight_decay_fn(p): - weight_decay_values[i] = 0.0 + apply_weight_decay[i] = 0 startup_block = self.helper.startup_program.global_block() for g in grads: @@ -225,8 +225,6 @@ class DistributedFusedLamb(Optimizer): 'Moment2': [moment2], 'Beta1Pow': [beta1pow], 'Beta2Pow': [beta2pow], - 'FusedIndices': [fused_indices], - 'WeightDecay': [weight_decay], 'GlobalScale': [scale], 'ParamInfo': [param_info], 'ParamOut': params, @@ -235,12 +233,13 @@ class DistributedFusedLamb(Optimizer): 'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets], 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets], 'FusedParamOffsets': [fused_offsets], + 'ParamOrder': [param_order], }, attrs={ 'alignment': self._alignment, 'rank': rank, 'nranks': nranks, - 'weight_decay': weight_decay_values, + 'apply_weight_decay': apply_weight_decay, 'moment1': 0.0, 'moment2': 0.0, 'beta1': self._beta1, @@ -272,8 +271,6 @@ class DistributedFusedLamb(Optimizer): 'Moment2': [moment2], 'Beta1Pow': [beta1pow], 'Beta2Pow': [beta2pow], - 'FusedIndices': [fused_indices], - 'WeightDecay': [weight_decay], 'GlobalScale': [scale], 'ParamInfo': [param_info], 'Param': params, @@ -281,6 +278,7 @@ class DistributedFusedLamb(Optimizer): 'FusedParamOffsets': [fused_offsets], 'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets], 'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets], + 'ParamOrder': [param_order], }, outputs={ 'FP32FusedParamOut': [fp32_fused_param], @@ -294,6 +292,7 @@ class DistributedFusedLamb(Optimizer): 'FoundInf': [self._found_inf], }, attrs={ + 'weight_decay': self._weight_decay, 'beta1': self._beta1, 'beta2': self._beta2, 'epsilon': self._epsilon, -- GitLab