From c770053cb2230c1893a4d4995d45b95183a297d1 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Wed, 3 Aug 2022 10:41:10 +0800
Subject: [PATCH] Add use_hierarchical_allreduce for DistributedFusedLAMB
 (#44821)

* add use_hierarchical_allreduce

* support hierarchical allreduce for more cases
---
 .../optimizers/distributed_fused_lamb_op.cc |   3 +
 .../optimizers/distributed_fused_lamb_op.cu | 206 ++++++++++++++----
 .../optimizer/distributed_fused_lamb.py     |  12 +
 3 files changed, 175 insertions(+), 46 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
index b85eb16a39..9f286fef47 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
@@ -152,6 +152,9 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<int>>("ring_id",
                               "The ring id of the NCCL communicator.")
         .SetDefault({0});
+    AddAttr<bool>("use_hierarchical_allreduce",
+                  "Whether to use hierarchical allreduce")
+        .SetDefault(false);
     AddComment("The DistributedFusedLamb optimizer.");
   }
 };
diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
index 394e49dd52..f1b852301d 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
@@ -1614,15 +1614,20 @@ class DistributedFusedLambOpKernel
     const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
     auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
     auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
+    auto use_hierarchical_allreduce =
+        ctx.Attr<bool>("use_hierarchical_allreduce");
     VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
              << " , clip_after_allreduce = " << clip_after_allreduce
              << " , use_master_param_norm = " << use_master_param_norm
              << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
-             << " , local_shard = " << local_shard;
+             << " , local_shard = " << local_shard
+             << " , use_hierarchical_allreduce = "
+             << use_hierarchical_allreduce;

     // Step 6: allreduce + global norm gradient clip
     int64_t global_rank = 0, local_rank = 0;
-    ncclComm_t global_comm = nullptr, local_comm = nullptr;
+    ncclComm_t global_comm = nullptr, local_comm = nullptr,
+               external_comm = nullptr;
     if (nranks > 1) {
       auto *nccl_comm_handle =
           platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
@@ -1634,6 +1639,11 @@ class DistributedFusedLambOpKernel
           platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
       local_comm = local_nccl_comm_handle->comm();
       local_rank = local_nccl_comm_handle->rank();
+      if (use_hierarchical_allreduce) {
+        external_comm = platform::NCCLCommContext::Instance()
+                            .Get(ring_ids[2], place)
+                            ->comm();
+      }
     } else {
       local_comm = global_comm;
       local_rank = global_rank;
@@ -1686,20 +1696,54 @@ class DistributedFusedLambOpKernel
     if (clip_after_allreduce) {
       // (1) ReduceScater first
       if (local_shard) {
-        NCCLAllReduceWithScale(fp32_grad,
-                               fp32_sum_grad,
-                               fp32_numel,
-                               nranks,
-                               global_comm,
-                               stream,
-                               dev_ctx);
-        NCCLAllReduceWithScale(fp16_grad,
-                               fp16_sum_grad,
-                               fp16_numel,
-                               nranks,
-                               global_comm,
-                               stream,
-                               dev_ctx);
+        if (use_hierarchical_allreduce) {
+          NCCLAllReduceWithScale(fp32_grad,
+                                 fp32_sum_grad,
+                                 fp32_numel,
+                                 nranks / num_devices,
+                                 external_comm,
+                                 stream,
+                                 dev_ctx);
+          NCCLReduceScatterWithScale(
+              fp32_sum_grad,
+              fp32_sum_grad + local_rank * fp32_numel_each_device,
+              fp32_numel_each_device,
+              num_devices,
+              local_comm,
+              stream,
+              dev_ctx);
+
+          NCCLAllReduceWithScale(fp16_grad,
+                                 fp16_sum_grad,
+                                 fp16_numel,
+                                 nranks / num_devices,
+                                 external_comm,
+                                 stream,
+                                 dev_ctx);
+          NCCLReduceScatterWithScale(
+              fp16_sum_grad,
+              fp16_sum_grad + local_rank * fp16_numel_each_device,
+              fp16_numel_each_device,
+              num_devices,
+              local_comm,
+              stream,
+              dev_ctx);
+        } else {
+          NCCLAllReduceWithScale(fp32_grad,
+                                 fp32_sum_grad,
+                                 fp32_numel,
+                                 nranks,
+                                 global_comm,
+                                 stream,
+                                 dev_ctx);
+          NCCLAllReduceWithScale(fp16_grad,
+                                 fp16_sum_grad,
+                                 fp16_numel,
+                                 nranks,
+                                 global_comm,
+                                 stream,
+                                 dev_ctx);
+        }
         fp32_sum_grad += (local_rank * fp32_numel_each_device);
         fp16_sum_grad += (local_rank * fp16_numel_each_device);
       } else {
@@ -1794,22 +1838,58 @@ class DistributedFusedLambOpKernel
       }
       // (3) Do ReduceScatter with scale
       if (local_shard) {
-        NCCLAllReduceWithScale(fp32_grad,
-                               fp32_sum_grad,
-                               fp32_numel,
-                               nranks,
-                               global_comm,
-                               stream,
-                               dev_ctx,
-                               fp32_scale);
-        NCCLAllReduceWithScale(fp16_grad,
-                               fp16_sum_grad,
-                               fp16_numel,
-                               nranks,
-                               global_comm,
-                               stream,
-                               dev_ctx,
-                               fp16_scale);
+        if (use_hierarchical_allreduce) {
+          NCCLAllReduceWithScale(fp32_grad,
+                                 fp32_sum_grad,
+                                 fp32_numel,
+                                 nranks / num_devices,
+                                 external_comm,
+                                 stream,
+                                 dev_ctx,
+                                 fp32_scale);
+          NCCLReduceScatterWithScale(
+              fp32_sum_grad,
+              fp32_sum_grad + local_rank * fp32_numel_each_device,
+              fp32_numel_each_device,
+              num_devices,
+              local_comm,
+              stream,
+              dev_ctx);
+
+          NCCLAllReduceWithScale(fp16_grad,
+                                 fp16_sum_grad,
+                                 fp16_numel,
+                                 nranks / num_devices,
+                                 external_comm,
+                                 stream,
+                                 dev_ctx,
+                                 fp16_scale);
+          NCCLReduceScatterWithScale(
+              fp16_sum_grad,
+              fp16_sum_grad + local_rank * fp16_numel_each_device,
+              fp16_numel_each_device,
+              num_devices,
+              local_comm,
+              stream,
+              dev_ctx);
+        } else {
+          NCCLAllReduceWithScale(fp32_grad,
+                                 fp32_sum_grad,
+                                 fp32_numel,
+                                 nranks,
+                                 global_comm,
+                                 stream,
+                                 dev_ctx,
+                                 fp32_scale);
+          NCCLAllReduceWithScale(fp16_grad,
+                                 fp16_sum_grad,
+                                 fp16_numel,
+                                 nranks,
+                                 global_comm,
+                                 stream,
+                                 dev_ctx,
+                                 fp16_scale);
+        }
         fp32_sum_grad += (local_rank * fp32_numel_each_device);
         fp16_sum_grad += (local_rank * fp16_numel_each_device);
       } else {
@@ -1836,20 +1916,54 @@ class DistributedFusedLambOpKernel
       }
     } else {
       if (local_shard) {
-        NCCLAllReduceWithScale(fp32_grad,
-                               fp32_sum_grad,
-                               fp32_numel,
-                               nranks,
-                               global_comm,
-                               stream,
-                               dev_ctx);
-        NCCLAllReduceWithScale(fp16_grad,
-                               fp16_sum_grad,
-                               fp16_numel,
-                               nranks,
-                               global_comm,
-                               stream,
-                               dev_ctx);
+        if (use_hierarchical_allreduce) {
+          NCCLAllReduceWithScale(fp32_grad,
+                                 fp32_sum_grad,
+                                 fp32_numel,
+                                 nranks / num_devices,
+                                 external_comm,
+                                 stream,
+                                 dev_ctx);
+          NCCLReduceScatterWithScale(
+              fp32_sum_grad,
+              fp32_sum_grad + local_rank * fp32_numel_each_device,
+              fp32_numel_each_device,
+              num_devices,
+              local_comm,
+              stream,
+              dev_ctx);
+
+          NCCLAllReduceWithScale(fp16_grad,
+                                 fp16_sum_grad,
+                                 fp16_numel,
+                                 nranks / num_devices,
+                                 external_comm,
+                                 stream,
+                                 dev_ctx);
+          NCCLReduceScatterWithScale(
+              fp16_sum_grad,
+              fp16_sum_grad + local_rank * fp16_numel_each_device,
+              fp16_numel_each_device,
+              num_devices,
+              local_comm,
+              stream,
+              dev_ctx);
+        } else {
+          NCCLAllReduceWithScale(fp32_grad,
+                                 fp32_sum_grad,
+                                 fp32_numel,
+                                 nranks,
+                                 global_comm,
+                                 stream,
+                                 dev_ctx);
+          NCCLAllReduceWithScale(fp16_grad,
+                                 fp16_sum_grad,
+                                 fp16_numel,
+                                 nranks,
+                                 global_comm,
+                                 stream,
+                                 dev_ctx);
+        }
         fp32_sum_grad += (local_rank * fp32_numel_each_device);
         fp16_sum_grad += (local_rank * fp16_numel_each_device);
       } else {
diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
index d283eae392..d230b6afca 100644
--- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py
+++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
@@ -101,6 +101,7 @@ class DistributedFusedLamb(Optimizer):
                  gradient_accumulation_steps=1,
                  use_master_acc_grad=True,
                  nproc_per_node=None,
+                 use_hierarchical_allreduce=False,
                  name=None):
         assert not framework._non_static_mode(
         ), "DistributedFusedLamb does not support dygraph mode"
@@ -129,6 +130,7 @@ class DistributedFusedLamb(Optimizer):
         self._gradient_accumulation_steps = gradient_accumulation_steps
         self._use_master_acc_grad = use_master_acc_grad
         self._nproc_per_node = nproc_per_node
+        self._use_hierarchical_allreduce = use_hierarchical_allreduce
         assert self._gradient_accumulation_steps >= 1

         self.helper = LayerHelper('distributed_fused_lamb')
@@ -305,6 +307,7 @@ class DistributedFusedLamb(Optimizer):
                                        list(range(nranks)), 0)
             ring_ids.append(ring_id)

+        use_hierarchical_allreduce = False
         if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
             local_group_ranks = list(
                 range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node))
@@ -312,6 +315,14 @@ class DistributedFusedLamb(Optimizer):
                                        1)
             ring_ids.append(ring_id)

+        if self._use_hierarchical_allreduce and nranks > nproc_per_node:
+            use_hierarchical_allreduce = True
+            outer_group_ranks = list(
+                range(rank % nproc_per_node, nranks, nproc_per_node))
+            ring_id = init_communicator(startup_block, rank,
+                                        outer_group_ranks, ring_ids[-1] + 1)
+            ring_ids.append(ring_id)
+
         scale = self._get_or_create_scale()

         params = [p for p, _ in params_grads]
@@ -439,5 +450,6 @@ class DistributedFusedLamb(Optimizer):
                 'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                 'acc_steps': self._gradient_accumulation_steps,
                 'use_master_acc_grad': self._use_master_acc_grad,
+                'use_hierarchical_allreduce': use_hierarchical_allreduce,
             })
         return [lamb_op]
--
GitLab
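
Note on the kernel change above: in the `local_shard` branches, the flat `NCCLAllReduceWithScale` over all `nranks` is replaced by an allreduce over the outer communicator (the `nranks / num_devices` ranks that share a local index, created from `ring_ids[2]`) followed by an intra-node `NCCLReduceScatterWithScale` over `num_devices` ranks. The NumPy sketch below is only an illustration of why the two schemes produce the same per-rank gradient shard; the toy topology, sizes, and variable names are assumptions, not part of the patch.

```python
# Standalone NumPy sketch (not the CUDA kernel): check that an allreduce over
# the outer group (ranks sharing a local index across nodes, group size
# nranks / num_devices) followed by an intra-node reduce-scatter over
# num_devices ranks yields the same shard as a flat allreduce over all nranks.
import numpy as np

num_nodes, num_devices = 2, 4            # assumed toy topology
nranks = num_nodes * num_devices
numel = 16                               # divisible by num_devices
numel_each_device = numel // num_devices

rng = np.random.default_rng(0)
# grads[node, device] is the local gradient held by that rank.
grads = rng.standard_normal((num_nodes, num_devices, numel))

# Flat scheme: allreduce over all ranks, each rank then uses its own shard.
flat_sum = grads.sum(axis=(0, 1))

# (1) Outer allreduce: local rank d on every node ends up with the sum over
#     its outer group (the ranks that share local index d).
outer_sums = [grads[:, d, :].sum(axis=0) for d in range(num_devices)]

# (2) Intra-node reduce-scatter: the num_devices outer sums inside one node
#     are summed, and local rank d keeps only the d-th contiguous shard.
node_sum = np.sum(outer_sums, axis=0)
for local_rank in range(num_devices):
    lo = local_rank * numel_each_device
    hi = lo + numel_each_device
    assert np.allclose(node_sum[lo:hi], flat_sum[lo:hi])

print("hierarchical allreduce + reduce-scatter matches the flat allreduce shard")
```

This equivalence is why the hierarchical branch can still advance `fp32_sum_grad` and `fp16_sum_grad` by `local_rank * *_numel_each_device` exactly as the flat branch does.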
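On the Python side, the feature is exposed as the new `use_hierarchical_allreduce` constructor argument of `paddle.incubate.optimizer.DistributedFusedLamb`, and the extra outer communicator is only created when `nranks > nproc_per_node`. A minimal, hypothetical way to enable it is sketched below; apart from `use_hierarchical_allreduce=True` and its `False` default, the argument values and the omitted multi-node, static-graph training setup are assumptions rather than part of this patch.

```python
# Hypothetical usage sketch: only the use_hierarchical_allreduce flag comes
# from this patch; the other arguments and the surrounding distributed
# training setup (paddle.distributed.launch, fleet, etc.) are assumed/omitted.
import paddle
from paddle.incubate.optimizer import DistributedFusedLamb

paddle.enable_static()  # the constructor asserts that dygraph mode is off

optimizer = DistributedFusedLamb(
    learning_rate=1e-3,  # assumed value
    # New in this patch; only takes effect for multi-node jobs
    # (nranks > nproc_per_node), where ring_ids[2] holds the outer group.
    use_hierarchical_allreduce=True,
)
```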