From 108aeb28704e64a54f82b8a59266a4e9633f9949 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 28 Apr 2022 12:02:23 +0800 Subject: [PATCH] Add gradient merge for DistributedFusedLamb optimizer (#40177) * add gradient merge for DistributedFusedLamb * use master acc gradient * fix CI ut * polish * remove math_function_impl.h change * fix test_update_loss_scaling_op.py * try to fix XPU/NPU CI * add gm ut --- .../operators/amp/update_loss_scaling_op.cc | 24 ++- .../operators/amp/update_loss_scaling_op.cu | 24 ++- .../operators/amp/update_loss_scaling_op.h | 60 +++++- .../amp/update_loss_scaling_op_npu.cc | 5 +- .../optimizers/distributed_fused_lamb_op.cc | 10 + .../optimizers/distributed_fused_lamb_op.cu | 181 +++++++++++++++++- .../fluid/contrib/mixed_precision/amp_nn.py | 6 +- .../contrib/mixed_precision/decorator.py | 2 +- .../fluid/tests/unittests/CMakeLists.txt | 2 + .../distributed_fused_lamb_test_base.py | 18 +- ...est_distributed_fused_lamb_op_with_clip.py | 5 +- ...buted_fused_lamb_op_with_gradient_merge.py | 28 +++ .../optimizer/distributed_fused_lamb.py | 35 ++++ 13 files changed, 369 insertions(+), 31 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_gradient_merge.py diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cc b/paddle/fluid/operators/amp/update_loss_scaling_op.cc index b974f60672..8354650df0 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.cc +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cc @@ -68,6 +68,18 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel { return framework::OpKernelType(dtype, ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { +#ifndef PADDLE_WITH_XPU + if (var_name == "FoundInfinite" || var_name == "StopUpdate") { + return expected_kernel_type; + } +#endif + return framework::OperatorWithKernel::GetKernelTypeForVar( + var_name, tensor, expected_kernel_type); + } }; class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker { @@ -93,6 +105,10 @@ class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("LossScaling", "(Tensor) 1-dim tensor, updated loss scaling."); AddOutput("OutGoodSteps", "(Tensor) 1-dim tensor, pdated good steps."); AddOutput("OutBadSteps", "(Tensor) 1-dim tensor, updated bad steps."); + AddOutput("StopUpdate", + "(Tensor) 1-dim tensor. Stop updating loss scaling, and just " + "zero inputs. It has higher priority than Attr(stop_update).") + .AsDispensable(); AddAttr("incr_every_n_steps", "A value represents increasing loss scaling every n " "consecutive steps with finite gradients."); @@ -131,8 +147,8 @@ decr_every_n_nan_or_inf steps and each step some gradients are infinite. } }; -template -class UpdateLossScalingFunctor { +template +class UpdateLossScalingFunctor { public: void operator()(const platform::CPUDeviceContext& ctx, const bool* found_inf_data, const T* pre_loss_scaling_data, @@ -141,6 +157,10 @@ class UpdateLossScalingFunctor { const int decr_every_n_nan_or_inf, const float incr_ratio, const float decr_ratio, T* updated_loss_scaling_data, int* good_out_data, int* bad_out_data) const { + PADDLE_ENFORCE_EQ( + IsFoundInfOnCPU, true, + platform::errors::InvalidArgument( + "The Input(FoundInfinite) should be on the CPUPlace.")); Update(found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, updated_loss_scaling_data, good_out_data, diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cu b/paddle/fluid/operators/amp/update_loss_scaling_op.cu index 6d9cd96a3f..43f8f84578 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.cu +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cu @@ -21,9 +21,9 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template __global__ void GpuUpdateLossScaling( - const bool* found_inf_data, const T* pre_loss_scaling_data, + const FoundNanInfFlagT found_inf_data, const T* pre_loss_scaling_data, const int* good_in_data, const int* bad_in_data, const int incr_every_n_steps, const int decr_every_n_nan_or_inf, const float incr_ratio, const float decr_ratio, @@ -70,8 +70,9 @@ __global__ void FusedFillIf(T** outs, const size_t xs_size, } } -template -class UpdateLossScalingFunctor { +template +class UpdateLossScalingFunctor { public: void operator()(const platform::CUDADeviceContext& dev_ctx, const bool* found_inf_data, const T* pre_loss_scaling_data, @@ -80,10 +81,17 @@ class UpdateLossScalingFunctor { const int decr_every_n_nan_or_inf, const float incr_ratio, const float decr_ratio, T* updated_loss_scaling_data, int* good_out_data, int* bad_out_data) const { - GpuUpdateLossScaling<<<1, 1, 0, dev_ctx.stream()>>>( - found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data, - incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, - updated_loss_scaling_data, good_out_data, bad_out_data); + if (IsFoundInfOnCPU) { + GpuUpdateLossScaling<<<1, 1, 0, dev_ctx.stream()>>>( + *found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data, + incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, + updated_loss_scaling_data, good_out_data, bad_out_data); + } else { + GpuUpdateLossScaling<<<1, 1, 0, dev_ctx.stream()>>>( + found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data, + incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, + updated_loss_scaling_data, good_out_data, bad_out_data); + } } }; diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.h b/paddle/fluid/operators/amp/update_loss_scaling_op.h index d6eddd36a4..41eb94247f 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.h +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.h @@ -25,6 +25,7 @@ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { namespace operators { @@ -40,8 +41,16 @@ inline HOSTDEVICE bool check_finite(T value) { #endif } -template -inline HOSTDEVICE void Update(const bool* found_inf_data, +inline HOSTDEVICE bool IsFoundNanInf(const bool found_nan_inf_data) { + return found_nan_inf_data; +} + +inline HOSTDEVICE bool IsFoundNanInf(const bool* found_nan_inf_data) { + return *found_nan_inf_data; +} + +template +inline HOSTDEVICE void Update(const FoundInfFlagT found_inf_data, const T* pre_loss_scaling_data, const int* good_in_data, const int* bad_in_data, const int incr_every_n_steps, @@ -49,7 +58,7 @@ inline HOSTDEVICE void Update(const bool* found_inf_data, const float incr_ratio, const float decr_ratio, T* updated_loss_scaling_data, int* good_out_data, int* bad_out_data) { - if (*found_inf_data) { + if (IsFoundNanInf(found_inf_data)) { *good_out_data = 0; *bad_out_data = *bad_in_data + 1; if (*bad_out_data == decr_every_n_nan_or_inf) { @@ -72,7 +81,7 @@ inline HOSTDEVICE void Update(const bool* found_inf_data, } } -template +template class UpdateLossScalingFunctor { public: void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data, @@ -106,9 +115,33 @@ class UpdateLossScalingKernel : public framework::OpKernel { platform::errors::InvalidArgument( "FoundInfinite must has only one element.")); const bool* found_inf_data = found_inf->data(); + bool is_found_inf_on_cpu = platform::is_cpu_place(found_inf->place()); + + if (is_found_inf_on_cpu) { + if (*found_inf_data) { + phi::funcs::SetConstant set_constant; + for (auto* out : outs) { + out->mutable_data(dev_ctx.GetPlace()); + set_constant(dev_ctx, out, static_cast(0)); + } + } + } else { + LazyZeros{}(dev_ctx, found_inf_data, xs, outs); + } - LazyZeros{}(dev_ctx, found_inf_data, xs, outs); - const bool stop_update = ctx.Attr("stop_update"); + const auto* stop_update_tensor = ctx.Input("StopUpdate"); + bool stop_update = false; + if (stop_update_tensor && stop_update_tensor->IsInitialized()) { + if (platform::is_cpu_place(stop_update_tensor->place())) { + stop_update = stop_update_tensor->data()[0]; + } else { + framework::Tensor tmp_tensor; + framework::TensorCopySync(*stop_update_tensor, platform::CPUPlace(), + &tmp_tensor); + stop_update = tmp_tensor.data()[0]; + } + } + stop_update |= ctx.Attr("stop_update"); if (stop_update) { return; } @@ -133,10 +166,17 @@ class UpdateLossScalingKernel : public framework::OpKernel { ctx.Attr("decr_every_n_nan_or_inf"); const float incr_ratio = ctx.Attr("incr_ratio"); const float decr_ratio = ctx.Attr("decr_ratio"); - UpdateLossScalingFunctor{}( - dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data, - bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, - decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data); + if (is_found_inf_on_cpu) { + UpdateLossScalingFunctor{}( + dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data, + bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, + decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data); + } else { + UpdateLossScalingFunctor{}( + dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data, + bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, + decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data); + } } }; diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc index 1393da7dd5..5808841333 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc +++ b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc @@ -131,7 +131,8 @@ void Update(const platform::NPUDeviceContext& ctx, } template -class UpdateLossScalingFunctor { +class UpdateLossScalingFunctor { public: void operator()(const platform::NPUDeviceContext& dev_ctx, const std::vector found_inf_vec, @@ -236,7 +237,7 @@ class UpdateLossScalingNPUKernel : public framework::OpKernel { ctx.Attr("decr_every_n_nan_or_inf"); const float incr_ratio = ctx.Attr("incr_ratio"); const float decr_ratio = ctx.Attr("decr_ratio"); - UpdateLossScalingFunctor{}( + UpdateLossScalingFunctor{}( dev_ctx, found_inf_vec, pre_loss_scaling, good_in, bad_in, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, updated_loss_scaling, good_out, bad_out); diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc index 161483c342..0159e250d3 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc @@ -100,6 +100,10 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddOutput("FP16FusedParamOut", "The updated FP16FusedParam.") .AsDispensable(); + AddOutput("FP32AccFusedGrad", "The accumulated FP32 gradients.") + .AsDispensable(); + AddOutput("FP16AccFusedGrad", "The accumulated FP16 gradients.") + .AsDispensable(); AddOutput("Moment1Out", "The updated Moment1."); AddOutput("Moment2Out", "The updated Moment2."); @@ -110,8 +114,14 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker { .AsDuplicable(); AddOutput("FoundInf", "Whether there is NaN/Inf"); + AddOutput("AccStep", "The training steps.").AsDispensable(); + AddOutput("StopUpdate", + "Whether the parameter updating is stopped when the gradient " + "accumulated steps is less than Attr(acc_steps).") + .AsDispensable(); AddOutput("Step", "The global step which excludes the NaN/Inf step."); + AddAttr("acc_steps", "The gradient accumulation steps.").SetDefault(1); AddAttr("beta1", "The initial Beta1Pow value."); AddAttr("beta2", "The initial Beta2Pow value."); AddAttr("epsilon", diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index f445a140f2..c857c6de4d 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -1041,6 +1041,58 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel, } } +template +static __global__ void ElementwiseAddWithCastCUDAKernel(const T1 *x, + const T2 *y, T3 *z, + int n) { + static_assert(sizeof(T1) <= sizeof(T2), + "sizeof(T1) must be smaller than sizeof(T2)."); + using MT = MasterT; + + int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize; + int stride = (blockDim.x * gridDim.x) * VecSize; + for (; i + VecSize <= n; i += stride) { + phi::AlignedVector x_vec; + phi::AlignedVector y_vec; + phi::AlignedVector z_vec; + phi::Load(x + i, &x_vec); + phi::Load(y + i, &y_vec); +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + auto x_tmp = static_cast(x_vec[j]); + auto y_tmp = static_cast(y_vec[j]); + z_vec[j] = static_cast(x_tmp + y_tmp); + } + phi::Store(z_vec, z + i); + } + + for (; i < n; ++i) { + auto x_tmp = static_cast(x[i]); + auto y_tmp = static_cast(y[i]); + z[i] = static_cast(x_tmp + y_tmp); + } +} + +template +static void LaunchElementwiseAddWithCastKernel( + const platform::CUDADeviceContext &dev_ctx, const T1 *x, const T2 *y, T3 *z, + int n, gpuStream_t stream) { + int vec_size = + std::min(std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0)), + GetChunkedVecSize(z, 0)); + auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size); + +#define PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL \ + do { \ + ElementwiseAddWithCastCUDAKernel<<< \ + config.block_per_grid, config.thread_per_block, 0, stream>>>(x, y, z, \ + n); \ + } while (0) + + PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL); +#undef PD_LAUNCH_ELEMENTWISE_ADD_WITH_CAST_KERNEL +} + template class DistributedFusedLambOpKernel : public framework::OpKernel { @@ -1051,6 +1103,9 @@ class DistributedFusedLambOpKernel auto stream = dev_ctx.stream(); auto place = dev_ctx.GetPlace(); + auto *found_inf_t = ctx.Output("FoundInf"); + found_inf_t->Resize({1}); + // Step 1: Get fp16 param and grad tensors int64_t fp16_numel; auto *fp16_param = GetSameInOutTensorPtr( @@ -1095,6 +1150,128 @@ class DistributedFusedLambOpKernel "Too many parameter number. Only <= %d is supported.", std::numeric_limits::max())); + auto acc_steps = ctx.Attr("acc_steps"); + PADDLE_ENFORCE_GE( + acc_steps, 1, + platform::errors::InvalidArgument( + "The gradient accumulation steps should be not less than 1.")); + if (acc_steps > 1) { + auto *step_t = ctx.Output("AccStep"); + PADDLE_ENFORCE_NOT_NULL( + step_t, + platform::errors::InvalidArgument( + "Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1.")); + bool is_initialized = step_t->IsInitialized(); + int64_t *step_ptr; + if (is_initialized) { + step_ptr = step_t->mutable_data(platform::CPUPlace()); + ++(*step_ptr); + } else { + step_t->Resize({1}); + step_ptr = step_t->mutable_data(platform::CPUPlace()); + *step_ptr = 1; + } + int64_t rounded_step = (*step_ptr) % acc_steps; + + float *fp32_acc_grad = nullptr; + if (has_fp32_param) { + auto *fp32_acc_grad_t = + ctx.Output("FP32AccFusedGrad"); + PADDLE_ENFORCE_NOT_NULL( + fp32_acc_grad_t, platform::errors::InvalidArgument( + "Output(FP32AccFusedGrad) cannot be nullptr " + "when Attr(acc_steps) > 1.")); + if (!fp32_acc_grad_t->IsInitialized()) { + fp32_acc_grad_t->Resize({static_cast(fp32_numel)}); + fp32_acc_grad = fp32_acc_grad_t->mutable_data(place); + } else { + fp32_acc_grad = fp32_acc_grad_t->data(); + } + } + + platform::float16 *fp16_acc_grad = nullptr; + float *master_acc_grad = nullptr; + if (has_fp16_param) { + auto *fp16_acc_grad_t = + ctx.Output("FP16AccFusedGrad"); + PADDLE_ENFORCE_NOT_NULL( + fp16_acc_grad_t, platform::errors::InvalidArgument( + "Output(FP16AccFusedGrad) cannot be nullptr " + "when Attr(acc_steps) > 1.")); + if (!fp16_acc_grad_t->IsInitialized()) { + fp16_acc_grad_t->Resize({static_cast(3 * fp16_numel)}); + fp16_acc_grad = + fp16_acc_grad_t->mutable_data(place); + } else { + fp16_acc_grad = fp16_acc_grad_t->data(); + } + master_acc_grad = reinterpret_cast(fp16_acc_grad + fp16_numel); + } + + // Inplace addto + if (has_fp32_param) { + if (rounded_step == 1) { + memory::Copy(place, fp32_acc_grad, place, fp32_grad, + fp32_numel * sizeof(float), stream); + } else { + LaunchElementwiseAddWithCastKernel(dev_ctx, fp32_grad, fp32_acc_grad, + fp32_acc_grad, fp32_numel, stream); + } + } + + if (has_fp16_param) { + if (acc_steps == 2) { + if (rounded_step == 0) { + LaunchElementwiseAddWithCastKernel(dev_ctx, fp16_acc_grad, + fp16_grad, fp16_acc_grad, + fp16_numel, stream); + } else { + memory::Copy(place, fp16_acc_grad, place, fp16_grad, + fp16_numel * sizeof(platform::float16), stream); + } + } else { // acc_steps >= 3 + if (rounded_step == 0) { + LaunchElementwiseAddWithCastKernel(dev_ctx, fp16_grad, + master_acc_grad, fp16_acc_grad, + fp16_numel, stream); + } else if (rounded_step == 1) { + memory::Copy(place, fp16_acc_grad, place, fp16_grad, + fp16_numel * sizeof(platform::float16), stream); + } else if (rounded_step == 2) { + LaunchElementwiseAddWithCastKernel(dev_ctx, fp16_grad, + fp16_acc_grad, master_acc_grad, + fp16_numel, stream); + } else { + LaunchElementwiseAddWithCastKernel(dev_ctx, fp16_grad, + master_acc_grad, master_acc_grad, + fp16_numel, stream); + } + } + } + + auto *stop_update_t = ctx.Output("StopUpdate"); + stop_update_t->Resize({1}); + auto *stop_update = + stop_update_t->mutable_data(platform::CPUPlace()); + + auto *found_inf_cpu = + found_inf_t->mutable_data(platform::CPUPlace()); + + if (rounded_step != 0) { + *stop_update = true; + auto *found_inf_cpu = + found_inf_t->mutable_data(platform::CPUPlace()); + *found_inf_cpu = false; + return; + } else { + // swap pointer + fp32_grad = fp32_acc_grad; + fp16_grad = fp16_acc_grad; + *stop_update = false; + found_inf_t->clear(); + } + } + // Step 3: Get ParamInfo const auto *param_info_tensor = GetInputTensorPtr(ctx, "ParamInfo"); auto fp32_local_start_idx = param_info_tensor[0]; @@ -1122,7 +1299,7 @@ class DistributedFusedLambOpKernel << " , fp16_global_param_num = " << fp16_global_param_num; // Step 4: Get LearningRate, Moment1, Moment2, Beta1Pow, Beta2Pow, - // GlobalScale, FoundInf + // GlobalScale const auto *global_scale = GetInputTensorPtr(ctx, "GlobalScale"); const auto *lr = GetInputTensorPtr(ctx, "LearningRate"); int64_t partial_numel = 0; @@ -1157,8 +1334,6 @@ class DistributedFusedLambOpKernel auto *beta2pow = GetSameInOutTensorPtr(ctx, place, "Beta2Pow", "Beta2PowOut"); - auto *found_inf_t = ctx.Output("FoundInf"); - found_inf_t->Resize({1}); auto *found_inf = found_inf_t->mutable_data(place); // Step 5: Get attributes weight_decay, beta1, beta2, epsilon, diff --git a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py index 588eb2a29f..c5b9b9e71f 100644 --- a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py +++ b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py @@ -129,9 +129,13 @@ def update_loss_scaling(x, 'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf, 'incr_ratio': incr_ratio, 'decr_ratio': decr_ratio, - 'stop_update': stop_update } + if isinstance(stop_update, Variable): + inputs['StopUpdate'] = stop_update + else: + attrs['stop_update'] = stop_update + helper.append_op( type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs) diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index c6e2bcb8b1..c3720396e1 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -432,7 +432,7 @@ class OptimizerWithMixedPrecision(object): self._decr_every_n_nan_or_inf, self._incr_ratio, self._decr_ratio, - stop_update=False, + stop_update=self._optimizer._get_stop_update_var(), name="update_loss_scaling") return diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 12ed7b975a..15dd3d8b8f 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -914,6 +914,7 @@ set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op test_inp test_parallel_executor_seresnext_with_fuse_all_reduce_gpu test_distributed_fused_lamb_op_with_clip test_distributed_fused_lamb_op_without_clip + test_distributed_fused_lamb_op_with_gradient_merge test_parallel_executor_fetch_isolated_var PROPERTIES LABELS "RUN_TYPE=DIST") @@ -1047,6 +1048,7 @@ set_tests_properties(test_row_conv_op PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_gpu PROPERTIES TIMEOUT 120) set_tests_properties(test_distributed_fused_lamb_op_with_clip PROPERTIES TIMEOUT 120) set_tests_properties(test_distributed_fused_lamb_op_without_clip PROPERTIES TIMEOUT 120) +set_tests_properties(test_distributed_fused_lamb_op_with_gradient_merge PROPERTIES TIMEOUT 120) set_tests_properties(test_elementwise_min_op PROPERTIES TIMEOUT 120) set_tests_properties(test_nan_inf PROPERTIES TIMEOUT 120) set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 300) diff --git a/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py b/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py index 00d2a1f71d..0af7d40a2f 100644 --- a/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_fused_lamb_test_base.py @@ -149,6 +149,7 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): kwargs['exclude_from_weight_decay_fn'] = exclude_fn kwargs['lamb_weight_decay'] = 0.1 + gm_steps = kwargs['gradient_accumulation_steps'] if use_distributed_lamb: optimizer_class = DistributedFusedLamb kwargs = dict(kwargs) @@ -163,6 +164,7 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): ) kwargs['grad_clip'] = GradClipDecorator(base_clip, clip_after_allreduce) + kwargs.pop('gradient_accumulation_steps', None) optimizer = optimizer_class(**kwargs) get_parameter = optimizer._get_parameter @@ -173,6 +175,7 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): if use_fp16: if not use_distributed_lamb: optimizer._multi_precision = True + optimizer = paddle.static.amp.decorate( optimizer, amp_list, @@ -180,6 +183,13 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): use_dynamic_loss_scaling=False, use_pure_fp16=use_fp16, use_fp16_guard=use_fp16) + amp_init = optimizer.amp_init + else: + amp_init = None + + if gm_steps > 1 and not use_distributed_lamb: + optimizer = paddle.fluid.optimizer.GradientMergeOptimizer( + optimizer, k_steps=gm_steps, avg=False) params_grads = optimizer.backward(loss, startup) op_num = len(main.global_block().ops) @@ -211,7 +221,7 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): return grad_t def reader(): - for _ in range(5): + for _ in range(6): yield dict( [(grad.name, gen_random_grad_tensor(grad)) for grad in grads]) @@ -223,8 +233,8 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs): place = paddle.CUDAPlace(dev_id) exe = paddle.static.Executor(place) exe.run(startup) - if use_fp16: - optimizer.amp_init(place) + if amp_init is not None: + amp_init(place) master_p_ts = [] for p in params: @@ -258,10 +268,12 @@ class TestDistributedFusedLamb(unittest.TestCase): distutils.util.strtobool( os.getenv('CLIP_AFTER_ALLREDUCE', 'True'))) max_global_norm = float(os.getenv('MAX_GLOBAL_NORM', -1.0)) + gm_steps = int(os.getenv('GRADIENT_MERGE_STEPS', 1)) print('clip_after_allreduce = {}, max_global_norm = {}'.format( clip_after_allreduce, max_global_norm)) return { 'clip_after_allreduce': clip_after_allreduce, + 'gradient_accumulation_steps': gm_steps, 'grad_clip': paddle.nn.ClipGradByGlobalNorm(max_global_norm) if max_global_norm > 0 else None, } diff --git a/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py index af99529adf..315580dd31 100644 --- a/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py +++ b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py @@ -34,7 +34,9 @@ def remove_file_if_exists(file_name): shutil.rmtree(file_name) -def run_test(clip_after_allreduce=True, max_global_norm=-1.0): +def run_test(clip_after_allreduce=True, + max_global_norm=-1.0, + gradient_merge_steps=1): if not paddle.is_compiled_with_cuda(): return if os.name == 'nt': @@ -55,6 +57,7 @@ def run_test(clip_after_allreduce=True, max_global_norm=-1.0): os.environ['CLIP_AFTER_ALLREDUCE'] = str(clip_after_allreduce) os.environ['MAX_GLOBAL_NORM'] = str(max_global_norm) + os.environ['GRADIENT_MERGE_STEPS'] = str(gradient_merge_steps) touch_file_env = 'SUCCESS_TOUCH_FILE' touch_file_name = 'distributed_fused_lamb_touch_file_{}'.format(os.getpid()) diff --git a/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_gradient_merge.py b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_gradient_merge.py new file mode 100644 index 0000000000..1822b77d0d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_gradient_merge.py @@ -0,0 +1,28 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from test_distributed_fused_lamb_op_with_clip import run_test +import unittest + + +class TestDistributedFusedLambGradientMerge(unittest.TestCase): + def test_gm(self): + run_test( + clip_after_allreduce=True, + max_global_norm=-1.0, + gradient_merge_steps=2) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py index 74b5398230..4d40a477ff 100644 --- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py +++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py @@ -38,6 +38,7 @@ class DistributedFusedLamb(Optimizer): is_grad_scaled_by_nranks=True, alignment=128, use_master_param_norm=True, + gradient_accumulation_steps=1, name=None): assert not framework._non_static_mode( ), "DistributedFusedLamb does not support dygraph mode" @@ -63,6 +64,9 @@ class DistributedFusedLamb(Optimizer): self._scale = None self._ring_id = 0 self._use_master_param_norm = use_master_param_norm + self._gradient_accumulation_steps = gradient_accumulation_steps + assert self._gradient_accumulation_steps >= 1 + self.helper = LayerHelper('distributed_fused_lamb') self._supports_check_nan_inf = True # very import flag for AMP @@ -73,8 +77,19 @@ class DistributedFusedLamb(Optimizer): dtype=core.VarDesc.VarType.BOOL) self._step = None + if self._gradient_accumulation_steps > 1: + self._stop_update = main_block.create_var( + name=unique_name.generate('stop_update'), + shape=[1], + dtype=core.VarDesc.VarType.BOOL) + else: + self._stop_update = None + self._param_to_master_param = {} + def _get_stop_update_var(self): + return self._stop_update if self._stop_update is not None else False + def _set_step(self, step): self._step = step @@ -194,6 +209,20 @@ class DistributedFusedLamb(Optimizer): param_order = self._create_persistable_var('param_order', dtype='int32') param_order.is_distributed = True + if self._gradient_accumulation_steps > 1: + fp32_acc_fused_grad = [ + self._create_persistable_var('fp32_acc_fused_grad') + ] + fp16_acc_fused_grad = [ + self._create_persistable_var( + 'fp16_acc_fused_grad', dtype='float16') + ] + acc_step = [self._create_persistable_var('acc_step', dtype='int64')] + else: + fp32_acc_fused_grad = [] + fp16_acc_fused_grad = [] + acc_step = [] + step = self._get_or_create_step() rank = get_rank() @@ -298,6 +327,11 @@ class DistributedFusedLamb(Optimizer): 'ParamOut': params, 'GradOut': grads, 'FoundInf': [self._found_inf], + 'FP32AccFusedGrad': fp32_acc_fused_grad, + 'FP16AccFusedGrad': fp16_acc_fused_grad, + 'AccStep': acc_step, + 'StopUpdate': self._stop_update + if self._stop_update is not None else [], 'Step': [step], }, attrs={ @@ -311,5 +345,6 @@ class DistributedFusedLamb(Optimizer): 'ring_id': self._ring_id, 'use_master_param_norm': self._use_master_param_norm, 'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks, + 'acc_steps': self._gradient_accumulation_steps, }) return [lamb_op] -- GitLab