From 6678def99ae7c9b94286eb59ab9ef995a8563ab1 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Thu, 9 Jun 2022 16:12:28 +0800
Subject: [PATCH] Add nproc_per_node for DistributedFusedLamb (#43295)

* add nproc_per_node for DistributedFusedLamb
* fix nproc_per_node communicator bug
* fix ring_id = 1 init bug
* fix ci
* fix test_parallel_executor_mnist.py
---
 paddle/fluid/framework/parallel_executor.cc  |  27 ++-
 .../optimizers/distributed_fused_lamb_op.cc  |   6 +-
 .../optimizers/distributed_fused_lamb_op.cu  | 185 +++++++++++++-----
 .../graph_execution_optimizer.py             |   3 +
 .../optimizer/distributed_fused_lamb.py      | 169 ++++++++++++----
 5 files changed, 291 insertions(+), 99 deletions(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 00d48098a1..6f8621d30e 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -793,8 +793,8 @@ void ParallelExecutor::BCastParamsToDevices(
       std::vector<void *> buffers;
       buffers.reserve(member_->places_.size());
       size_t numel = main_tensor.numel();
-      ncclDataType_t data_type = platform::ToNCCLDataType(
-          framework::TransToProtoVarType(main_tensor.dtype()));
+      auto dtype = framework::TransToProtoVarType(main_tensor.dtype());
+      ncclDataType_t data_type = platform::ToNCCLDataType(dtype);
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto place = member_->places_[i];
         void *buffer;
@@ -815,7 +815,7 @@ void ParallelExecutor::BCastParamsToDevices(
                             "variables' buffer size to bcast is %d, which is "
                             "NOT equal to places size %d",
                             buffers.size(), member_->places_.size()));
-      {
+      if (member_->nccl_ctxs_ != nullptr) {
         auto *nccl_ctxs = member_->nccl_ctxs_->DefaultFlatCtx();
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &nccl_ctx = nccl_ctxs->at(member_->places_[i]);
           platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
                                        nccl_ctx.comm_, nccl_ctx.stream());
         }
         nccl_ctxs->WaitAll();
+      } else {
+        auto src_place = member_->places_[0];
+        auto src_dev_ctx = static_cast<platform::CUDADeviceContext *>(
+            platform::DeviceContextPool::Instance().Get(src_place));
+        auto sizeof_dtype = framework::SizeOfType(dtype) * numel;
+        for (size_t i = 1; i < member_->places_.size(); ++i) {
+          auto dst_place = member_->places_[i];
+          auto dst_dev_ctx = static_cast<platform::CUDADeviceContext *>(
+              platform::DeviceContextPool::Instance().Get(dst_place));
+          src_dev_ctx->Wait();
+          dst_dev_ctx->Wait();
+          memory::Copy(dst_place, buffers[i], src_place, buffers[0],
+                       sizeof_dtype, src_dev_ctx->stream());
+          src_dev_ctx->Wait();
+          dst_dev_ctx->Wait();
+        }
       }
 #endif
     } else if (paddle::platform::is_xpu_place(main_tensor.place())) {
@@ -1348,6 +1364,11 @@ std::vector<ir::Graph *> ParallelExecutor::CloneGraphToMultiDevices(
 }
 
 void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) {
+  if (member_->build_strategy_.reduce_ ==
+      BuildStrategy::ReduceStrategy::kNoReduce) {
+    return;
+  }
+
   if (member_->IsUseCUDA(member_->use_device_) && member_->nranks_ > 1) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     member_->InitOrGetNCCLCommunicator(global_scope, &member_->build_strategy_);
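Usage sketch (illustrative, not part of the patch): the user-facing entry point for all of the plumbing in this commit is a new `nproc_per_node` argument on `paddle.incubate.optimizer.DistributedFusedLamb`. The static-graph setup, the fleet-style launcher that sets the trainer endpoints, and the `learning_rate` value below are assumptions; only the constructor arguments that appear in this patch are taken from it.

# Illustrative usage of the new nproc_per_node option (static graph mode,
# launched by a distributed launcher that sets PADDLE_TRAINER_ENDPOINTS).
import paddle
from paddle.incubate.optimizer import DistributedFusedLamb

paddle.enable_static()

optimizer = DistributedFusedLamb(
    learning_rate=1e-3,              # assumed value, not from the patch
    is_grad_scaled_by_nranks=True,
    clip_after_allreduce=True,
    nproc_per_node=8)                # shard parameters inside each 8-GPU node;
                                     # defaults to the world size when None
# optimizer.minimize(loss) then runs inside the usual program/executor setup.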
AddAttr("is_grad_scaled_by_nranks", "Whether the input gradient has been scaled by nranks.") .SetDefault(true); - AddAttr("ring_id", "The ring id of the NCCL communicator.") - .SetDefault(0); + AddAttr("nranks", "The world size.").SetDefault(1); + AddAttr>("ring_id", + "The ring id of the NCCL communicator.") + .SetDefault({0}); AddComment("The DistributedFusedLamb optimizer."); } }; diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index e7f6223968..91c583fff8 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -806,23 +806,24 @@ static void LaunchScaleKernel(const platform::CUDADeviceContext &dev_ctx, #undef PD_LAMB_VEC_SCALE_KERNEL_CASE } -template -static void NCCLReduceScatterWithScale( - const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks, - ncclComm_t comm, gpuStream_t stream, - const platform::CUDADeviceContext &dev_ctx, const T *scale = nullptr) { +template +static void NCCLSumWithScaleBase(const T *sendbuff, T *recvbuff, + size_t recvcount, size_t nranks, + ncclComm_t comm, gpuStream_t stream, + const platform::CUDADeviceContext &dev_ctx, + const T *scale = nullptr) { static_assert(std::is_same::value || std::is_same::value, "T must be either float32 or float16."); if (recvcount == 0) return; + auto numel = UseReduceScatter ? (recvcount * nranks) : recvcount; if (comm == nullptr) { if (scale != nullptr) { PADDLE_ENFORCE_EQ(nranks, 1, platform::errors::InvalidArgument( "nranks must be 1 when scale != nullptr.")); - LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, recvcount * nranks, - stream); + LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, numel, stream); } return; } @@ -834,14 +835,18 @@ static void NCCLReduceScatterWithScale( scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op); memory::Buffer buffer(dev_ctx.GetPlace()); if (scale && !should_destroy_op) { - size_t numel = recvcount * nranks; T *new_sendbuff = buffer.Alloc(numel); LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream); sendbuff = new_sendbuff; } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( - sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); + if (UseReduceScatter) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( + sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); + } #if NCCL_VERSION_CODE >= 21100 if (should_destroy_op) { @@ -851,6 +856,26 @@ static void NCCLReduceScatterWithScale( } #endif } + +template +static void NCCLReduceScatterWithScale( + const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks, + ncclComm_t comm, gpuStream_t stream, + const platform::CUDADeviceContext &dev_ctx, const T *scale = nullptr) { + NCCLSumWithScaleBase(sendbuff, recvbuff, recvcount, nranks, comm, + stream, dev_ctx, scale); +} + +template +static void NCCLAllReduceWithScale(const T *sendbuff, T *recvbuff, + size_t recvcount, size_t nranks, + ncclComm_t comm, gpuStream_t stream, + const platform::CUDADeviceContext &dev_ctx, + const T *scale = nullptr) { + NCCLSumWithScaleBase(sendbuff, recvbuff, recvcount, nranks, comm, + stream, dev_ctx, scale); +} + #endif template "exactly by the element number %d of Moment1.", numel, partial_numel)); + // The num_devices means the number of devices that shard 
 template
                          "exactly by the element number %d of Moment1.",
                          numel, partial_numel));
+    // The num_devices means the number of devices that shard a complete set
+    // of all parameters. It may be num_devices < nranks or num_devices ==
+    // nranks.
     int64_t num_devices = numel / partial_numel;
     VLOG(1) << "num_devices = " << num_devices
             << " , partial_numel = " << partial_numel;
@@ -1354,22 +1382,43 @@ class DistributedFusedLambOpKernel
     auto epsilon = ctx.Attr<float>("epsilon");
     auto max_global_grad_norm = ctx.Attr<float>("max_global_grad_norm");
     auto clip_after_allreduce = ctx.Attr<bool>("clip_after_allreduce");
-    auto ring_id = ctx.Attr<int>("ring_id");
+    auto nranks = ctx.Attr<int64_t>("nranks");
+    PADDLE_ENFORCE_GE(nranks, num_devices,
+                      phi::errors::InvalidArgument(
+                          "The nranks must be not less than num_devices."));
+    PADDLE_ENFORCE_EQ(
+        nranks % num_devices, 0,
+        phi::errors::InvalidArgument(
+            "The nranks must be exactly divided by num_devices."));
+    bool local_shard = (nranks > num_devices);
+
+    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
     auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
     auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
     VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
              << " , clip_after_allreduce = " << clip_after_allreduce
              << " , use_master_param_norm = " << use_master_param_norm
-             << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks;
+             << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
+             << " , local_shard = " << local_shard;
 
     // Step 6: allreduce + global norm gradient clip
-    int rank = 0;
-    ncclComm_t comm = nullptr;
-    if (num_devices > 1) {
+    int64_t global_rank = 0, local_rank = 0;
+    ncclComm_t global_comm = nullptr, local_comm = 0;
+    if (nranks > 1) {
       auto *nccl_comm_handle =
-          platform::NCCLCommContext::Instance().Get(ring_id, place);
-      comm = nccl_comm_handle->comm();
-      rank = nccl_comm_handle->rank();
+          platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
+      global_comm = nccl_comm_handle->comm();
+      global_rank = nccl_comm_handle->rank();
+
+      if (local_shard) {
+        auto *local_nccl_comm_handle =
+            platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
+        local_comm = local_nccl_comm_handle->comm();
+        local_rank = local_nccl_comm_handle->rank();
+      } else {
+        local_comm = global_comm;
+        local_rank = global_rank;
+      }
     }
 
     memory::Buffer grad_norm_square_buffer(place);
@@ -1381,8 +1430,15 @@ class DistributedFusedLambOpKernel
     platform::float16 *fp16_sum_grad;
     auto fp32_numel_each_device = fp32_numel / num_devices;
     auto fp16_numel_each_device = fp16_numel / num_devices;
-    if (num_devices > 1 ||
-        (max_global_grad_norm > 0 && !clip_after_allreduce)) {
+    if (local_shard) {
+      auto ptr = sum_grad_buffer.Alloc<uint8_t>(
+          fp32_numel * sizeof(float) + fp16_numel * sizeof(platform::float16));
+      fp32_sum_grad = has_fp32_param ? reinterpret_cast<float *>(ptr) : nullptr;
+      fp16_sum_grad = has_fp16_param ? reinterpret_cast<platform::float16 *>(
+                                           ptr + fp32_numel * sizeof(float))
+                                     : nullptr;
+    } else if (nranks > 1 ||
+               (max_global_grad_norm > 0 && !clip_after_allreduce)) {
       auto ptr = sum_grad_buffer.Alloc<uint8_t>(
           fp32_numel_each_device * sizeof(float) +
           fp16_numel_each_device * sizeof(platform::float16));
@@ -1404,18 +1460,27 @@ class DistributedFusedLambOpKernel
     float rescale_grad = 1.0f;
     if (!is_grad_scaled_by_nranks) {
-      rescale_grad /= num_devices;
+      rescale_grad /= nranks;
     }
 
     if (max_global_grad_norm > 0) {
       if (clip_after_allreduce) {
         // (1) ReduceScater first
-        NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
-                                   fp32_numel_each_device, num_devices, comm,
-                                   stream, dev_ctx);
-        NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
-                                   fp16_numel_each_device, num_devices, comm,
-                                   stream, dev_ctx);
+        if (local_shard) {
+          NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
+                                 global_comm, stream, dev_ctx);
+          NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
+                                 global_comm, stream, dev_ctx);
+          fp32_sum_grad += (local_rank * fp32_numel_each_device);
+          fp16_sum_grad += (local_rank * fp16_numel_each_device);
+        } else {
+          NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
+                                     fp32_numel_each_device, nranks,
+                                     global_comm, stream, dev_ctx);
+          NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
+                                     fp16_numel_each_device, nranks,
+                                     global_comm, stream, dev_ctx);
+        }
         // (2) Calculate the global grad norm
         GetSquareGradNorm(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad,
                           fp16_numel_each_device, fp32_square_grad_norm, stream,
@@ -1425,7 +1490,7 @@ class DistributedFusedLambOpKernel
         if (num_devices > 1) {
           PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
               fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
-              ncclSum, comm, stream));
+              ncclSum, local_comm, stream));
         }
         VLOG(1) << "Grad square norm after all reduce: "
                 << FlattenToString(fp32_square_grad_norm, 1, place);
@@ -1452,7 +1517,7 @@ class DistributedFusedLambOpKernel
         float clip_scale = 1.0f;
         if (is_grad_scaled_by_nranks) {
-          clip_scale *= num_devices;
+          clip_scale *= nranks;
         }
         CalcGradNormClipBeforeAllReduceScale
             <<<1, 1, 0, stream>>>(global_scale, max_global_grad_norm,
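With `local_shard` enabled the reduction now runs over all `nranks` processes instead of over the `num_devices` shards, so the two scaling knobs in this hunk are expressed in terms of `nranks`. A short sketch of that bookkeeping (illustrative; `grad_scales` is a hypothetical helper, in the kernel these are plain local scalars fed to the CUDA launches above):

def grad_scales(nranks, is_grad_scaled_by_nranks):
    # Mirrors the two scalars set above: rescale_grad divides the summed
    # gradient by nranks when the incoming gradients were not pre-averaged,
    # while clip_scale multiplies by nranks when they were, so the norm seen
    # by CalcGradNormClipBeforeAllReduceScale is consistently scaled.
    rescale_grad = 1.0 if is_grad_scaled_by_nranks else 1.0 / nranks
    clip_scale = float(nranks) if is_grad_scaled_by_nranks else 1.0
    return rescale_grad, clip_scale

# grad_scales(16, False) -> (0.0625, 1.0); grad_scales(16, True) -> (1.0, 16.0)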
@@ -1463,36 +1528,54 @@ class DistributedFusedLambOpKernel
       } else {
         VLOG(1) << "Grad scale: " << FlattenToString(fp16_scale, 1, place);
       }
-      if (num_devices > 1) {
+      if (nranks > 1) {
         PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
             fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
-            ncclSum, comm, stream));
+            ncclSum, global_comm, stream));
       }
       // (3) Do ReduceScatter with scale
-      NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
-                                 fp32_numel_each_device, num_devices, comm,
-                                 stream, dev_ctx, fp32_scale);
-      NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
-                                 fp16_numel_each_device, num_devices, comm,
-                                 stream, dev_ctx, fp16_scale);
+      if (local_shard) {
+        NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
+                               global_comm, stream, dev_ctx, fp32_scale);
+        NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
+                               global_comm, stream, dev_ctx, fp16_scale);
+        fp32_sum_grad += (local_rank * fp32_numel_each_device);
+        fp16_sum_grad += (local_rank * fp16_numel_each_device);
+      } else {
+        NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
+                                   fp32_numel_each_device, nranks,
+                                   global_comm, stream, dev_ctx, fp32_scale);
+        NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
+                                   fp16_numel_each_device, nranks,
+                                   global_comm, stream, dev_ctx, fp16_scale);
+      }
       // (4) mark max_global_grad_norm as 0, meaning that clip has been
       // already performed
       max_global_grad_norm = 0;
     }
   } else {
-    NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
-                               fp32_numel_each_device, num_devices, comm,
-                               stream, dev_ctx);
-    NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
-                               fp16_numel_each_device, num_devices, comm,
-                               stream, dev_ctx);
+    if (local_shard) {
+      NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, fp32_numel, nranks,
+                             global_comm, stream, dev_ctx);
+      NCCLAllReduceWithScale(fp16_grad, fp16_sum_grad, fp16_numel, nranks,
+                             global_comm, stream, dev_ctx);
+      fp32_sum_grad += (local_rank * fp32_numel_each_device);
+      fp16_sum_grad += (local_rank * fp16_numel_each_device);
+    } else {
+      NCCLReduceScatterWithScale(fp32_grad, fp32_sum_grad,
+                                 fp32_numel_each_device, num_devices,
+                                 global_comm, stream, dev_ctx);
+      NCCLReduceScatterWithScale(fp16_grad, fp16_sum_grad,
+                                 fp16_numel_each_device, num_devices,
+                                 global_comm, stream, dev_ctx);
+    }
     CheckHasNanInfGrad(fp32_sum_grad, fp32_numel_each_device, fp16_sum_grad,
                        fp16_numel_each_device, fp32_square_grad_norm, stream,
                        &cub_tmp_buffer);
     if (num_devices > 1) {
       PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
           fp32_square_grad_norm, fp32_square_grad_norm, 1, ncclFloat32,
-          ncclSum, comm, stream));
+          ncclSum, local_comm, stream));
     }
     max_global_grad_norm = 0;
   }
@@ -1526,8 +1609,8 @@ class DistributedFusedLambOpKernel
     memory::Buffer trust_ratio_div_buffer(place);
     auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
-    auto fp32_offset = rank * fp32_numel_each_device;
-    auto fp16_offset = rank * fp16_numel_each_device;
+    auto fp32_offset = local_rank * fp32_numel_each_device;
+    auto fp16_offset = local_rank * fp16_numel_each_device;
     if (has_fp32_param) {
       VLOG(10) << "Update FP32 Moment and TrustRatioDiv starts";
       MultiTensorUpdateLambMomentAndTrustRatioDiv(
@@ -1598,12 +1681,12 @@ class DistributedFusedLambOpKernel
       PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
           param_square_norm + fp32_global_param_num,
           param_square_norm + fp32_global_param_num,
-          2 * param_num - fp32_global_param_num, ncclFloat32, ncclSum, comm,
-          stream));
+          2 * param_num - fp32_global_param_num, ncclFloat32, ncclSum,
+          local_comm, stream));
     } else {
       PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
           trust_ratio_div_square_norm, trust_ratio_div_square_norm, param_num,
-          ncclFloat32, ncclSum, comm, stream));
+          ncclFloat32, ncclSum, local_comm, stream));
     }
     VLOG(10) << "ncclAllReduce done";
   }
@@ -1623,7 +1706,7 @@ class DistributedFusedLambOpKernel
         // ncclAllGather
         PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
             fp32_param + fp32_offset, fp32_param, fp32_numel_each_device,
-            ncclFloat32, comm, stream));
+            ncclFloat32, local_comm, stream));
       }
 
       beta1pow = nullptr;
@@ -1641,7 +1724,7 @@ class DistributedFusedLambOpKernel
         // ncclAllGather
         PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllGather(
             fp16_param + fp16_offset, fp16_param, fp16_numel_each_device,
-            ncclFloat16, comm, stream));
+            ncclFloat16, local_comm, stream));
       }
     }
     VLOG(10) << "Update Param done";
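After the reduction each rank holds (or points at) only its shard, so the LAMB update above writes the slice starting at `local_rank * numel_each_device` and the `ncclAllGather` calls reassemble the full parameter on every rank of the sharding group. A NumPy sketch of that reassembly (illustrative, not part of the patch):

import numpy as np

num_devices, numel_each_device = 4, 5
numel = num_devices * numel_each_device

# Each rank updates only its own slice of the fused parameter buffer ...
updated_shards = [np.random.rand(numel_each_device).astype(np.float32)
                  for _ in range(num_devices)]

# ... and ncclAllGather concatenates the slices in rank order, so afterwards
# every rank holds the complete updated fp32_param / fp16_param buffer.
full_param = np.concatenate(updated_shards)
assert full_param.shape == (numel,)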
diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py
index 8f42553048..a5b0856a66 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py
@@ -69,6 +69,9 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
             if trainer_id == 0 and not paddle.is_compiled_with_npu():
                 wait_server_ready(other_trainers)
 
+        if build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy._NoReduce:
+            return
+
         if core.is_compiled_with_cuda():
             comm_id_var = startup_program.global_block().create_var(
                 name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
index 3029c3a294..d283eae392 100644
--- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py
+++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from paddle.fluid import framework, core, layers, unique_name
 from paddle.fluid.framework import Variable
 from paddle.fluid.clip import ClipGradByGlobalNorm
@@ -19,11 +20,69 @@ from paddle.fluid.initializer import Constant
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.optimizer import Optimizer
 from paddle.distributed import get_rank, get_world_size
+from paddle.distributed.collective import new_group
 from paddle.fluid.executor import global_scope
 from paddle.fluid.framework import name_scope
+from paddle.fluid import core, unique_name
 import numpy as np
 
 
+def init_communicator(block, rank, ranks, ring_id):
+    eps = os.environ['PADDLE_TRAINER_ENDPOINTS']
+    eps = [ep.strip() for ep in eps.split(",") if ep.strip()]
+    cur_ep = eps[rank]
+    other_eps = [eps[r] for r in ranks if r != rank]
+
+    local_rank = ranks.index(rank)
+    comm_var_name = unique_name.generate('comm_id')
+    comm_id_var = block.create_var(name=comm_var_name,
+                                   persistable=True,
+                                   type=core.VarDesc.VarType.RAW)
+    block.append_op(type='c_gen_nccl_id',
+                    inputs={},
+                    outputs={'Out': comm_id_var},
+                    attrs={
+                        'rank': local_rank,
+                        'endpoint': cur_ep,
+                        'other_endpoints': other_eps,
+                        'ring_id': ring_id
+                    })
+    block.append_op(type='c_comm_init',
+                    inputs={'X': comm_id_var},
+                    outputs={},
+                    attrs={
+                        'nranks': len(ranks),
+                        'rank': local_rank,
+                        'ring_id': ring_id
+                    })
+    tmp_var = block.create_var(name=unique_name.generate('tmp'))
+    block.append_op(type='fill_constant',
+                    outputs={'Out': tmp_var},
+                    attrs={'value': 1})
+    block.append_op(type='c_allreduce_sum',
+                    inputs={'X': tmp_var},
+                    outputs={'Out': tmp_var},
+                    attrs={
+                        'ring_id': ring_id,
+                        'use_calc_stream': True
+                    })
+    block.append_op(type='c_sync_calc_stream',
+                    inputs={'X': tmp_var},
+                    outputs={'Out': tmp_var})
+    return ring_id
+
+
+def broadcast_parameters(block, parameters, ring_id):
+    for p in parameters:
+        block.append_op(type='c_broadcast',
+                        inputs={'X': p},
+                        outputs={'Out': p},
+                        attrs={
+                            'ring_id': ring_id,
+                            'use_calc_stream': True
+                        })
+
+
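The `init_communicator` helper above maps the global trainer list onto a sub-group: the endpoint list comes from `PADDLE_TRAINER_ENDPOINTS`, and the rank passed to `c_gen_nccl_id`/`c_comm_init` is the process's position inside that sub-group. A condensed sketch of the mapping (illustrative; `group_endpoints` is a hypothetical helper, not part of the patch):

import os

def group_endpoints(rank, ranks):
    # Mirrors init_communicator: pick this process's endpoint, the other
    # group members' endpoints, and its rank inside the sub-group that the
    # new NCCL ring is built over.
    eps = [ep.strip()
           for ep in os.environ['PADDLE_TRAINER_ENDPOINTS'].split(",")
           if ep.strip()]
    cur_ep = eps[rank]
    other_eps = [eps[r] for r in ranks if r != rank]
    local_rank = ranks.index(rank)
    return cur_ep, other_eps, local_rank

# e.g. global rank 5 inside the second node's group [4, 5, 6, 7] gets local_rank 1.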
 class DistributedFusedLamb(Optimizer):
 
     def __init__(self,
@@ -41,6 +100,7 @@ class DistributedFusedLamb(Optimizer):
                  use_master_param_norm=True,
                  gradient_accumulation_steps=1,
                  use_master_acc_grad=True,
+                 nproc_per_node=None,
                  name=None):
         assert not framework._non_static_mode(
         ), "DistributedFusedLamb does not support dygraph mode"
@@ -65,10 +125,10 @@ class DistributedFusedLamb(Optimizer):
         self._is_grad_scaled_by_nranks = is_grad_scaled_by_nranks
         self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
         self._scale = None
-        self._ring_id = 0
         self._use_master_param_norm = use_master_param_norm
         self._gradient_accumulation_steps = gradient_accumulation_steps
         self._use_master_acc_grad = use_master_acc_grad
+        self._nproc_per_node = nproc_per_node
         assert self._gradient_accumulation_steps >= 1
         self.helper = LayerHelper('distributed_fused_lamb')
@@ -228,6 +288,30 @@ class DistributedFusedLamb(Optimizer):
         rank = get_rank()
         nranks = get_world_size()
+        if self._nproc_per_node is None:
+            nproc_per_node = nranks
+        else:
+            nproc_per_node = self._nproc_per_node
+        assert nranks % nproc_per_node == 0, "nranks should be exactly divided by nproc_per_node"
+
+        shard_inside_node = (nranks > nproc_per_node)
+        local_rank = rank % nproc_per_node
+        node_id = int(rank / nproc_per_node)
+        node_num = int(nranks / nproc_per_node)
+        ring_ids = []
+        startup_block = self.helper.startup_program.global_block()
+        if nranks > 1:
+            ring_id = init_communicator(startup_block, rank,
+                                        list(range(nranks)), 0)
+            ring_ids.append(ring_id)
+
+        if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
+            local_group_ranks = list(
+                range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node))
+            ring_id = init_communicator(startup_block, rank, local_group_ranks,
+                                        1)
+            ring_ids.append(ring_id)
+
         scale = self._get_or_create_scale()
 
         params = [p for p, _ in params_grads]
@@ -238,7 +322,6 @@ class DistributedFusedLamb(Optimizer):
             if self._exclude_from_weight_decay_fn(p):
                 apply_weight_decay[i] = 0
 
-        startup_block = self.helper.startup_program.global_block()
         for g in grads:
             startup_block.create_var(name=g.name,
                                      type=g.type,
@@ -246,46 +329,45 @@ class DistributedFusedLamb(Optimizer):
                                      persistable=g.persistable,
                                      shape=g.shape)
 
-        startup_block.append_op(type='distributed_fused_lamb_init',
-                                inputs={
-                                    'Param': params,
-                                    'Grad': grads,
-                                },
-                                outputs={
-                                    'FP32FusedParam': [fp32_fused_param],
-                                    'FP32FusedGrad': [fp32_fused_grad],
-                                    'FP16FusedParam': [fp16_fused_param],
-                                    'FP16FusedGrad': [fp16_fused_grad],
-                                    'Moment1': [moment1],
-                                    'Moment2': [moment2],
-                                    'Beta1Pow': [beta1pow],
-                                    'Beta2Pow': [beta2pow],
-                                    'GlobalScale': [scale],
-                                    'ParamInfo': [param_info],
-                                    'ParamOut': params,
-                                    'MasterParamOut': master_params,
-                                    'GradOut': grads,
-                                    'FP32ShardFusedParamOffsets':
-                                    [fp32_partial_fused_offsets],
-                                    'FP16ShardFusedParamOffsets':
-                                    [fp16_partial_fused_offsets],
-                                    'FusedParamOffsets': [fused_offsets],
-                                    'ParamOrder': [param_order],
-                                    'Step': [step],
-                                },
-                                attrs={
-                                    'alignment': self._alignment,
-                                    'rank': rank,
-                                    'nranks': nranks,
-                                    'apply_weight_decay': apply_weight_decay,
-                                    'moment1': 0.0,
-                                    'moment2': 0.0,
-                                    'beta1': self._beta1,
-                                    'beta2': self._beta2,
-                                })
+        if nranks > 1:
+            broadcast_parameters(startup_block, params, ring_ids[0])
+
+        startup_block.append_op(
+            type='distributed_fused_lamb_init',
+            inputs={
+                'Param': params,
+                'Grad': grads,
+            },
+            outputs={
+                'FP32FusedParam': [fp32_fused_param],
+                'FP32FusedGrad': [fp32_fused_grad],
+                'FP16FusedParam': [fp16_fused_param],
+                'FP16FusedGrad': [fp16_fused_grad],
+                'Moment1': [moment1],
+                'Moment2': [moment2],
+                'Beta1Pow': [beta1pow],
+                'Beta2Pow': [beta2pow],
+                'GlobalScale': [scale],
+                'ParamInfo': [param_info],
+                'ParamOut': params,
+                'MasterParamOut': master_params,
+                'GradOut': grads,
+                'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
+                'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
+                'FusedParamOffsets': [fused_offsets],
+                'ParamOrder': [param_order],
+                'Step': [step],
+            },
+            attrs={
+                'alignment': self._alignment,
+                'rank': local_rank if shard_inside_node else rank,
+                'nranks': nproc_per_node if shard_inside_node else nranks,
+                'apply_weight_decay': apply_weight_decay,
+                'moment1': 0.0,
+                'moment2': 0.0,
+                'beta1': self._beta1,
+                'beta2': self._beta2,
+            })
 
         main_block = self.helper.main_program.global_block()
         self._create_global_learning_rate()
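To summarize the wiring in `_apply_gradients` above: ring 0 always spans every rank (it carries the gradient sum and the startup parameter broadcast), while ring 1 is created only for multi-node jobs that shard inside a node, and in that case `distributed_fused_lamb_init` is told to shard over the node-local group. A sketch of that decision (illustrative; `ring_layout` is a hypothetical helper, not part of the patch):

def ring_layout(nranks, nproc_per_node=None):
    # Mirrors the ring construction above and the (rank, nranks) attrs that
    # are passed to the distributed_fused_lamb_init op.
    nproc_per_node = nproc_per_node or nranks
    assert nranks % nproc_per_node == 0
    shard_inside_node = nranks > nproc_per_node
    node_num = nranks // nproc_per_node
    rings = []
    if nranks > 1:
        rings.append(0)                  # global ring over all ranks
    if node_num > 1 and shard_inside_node:
        rings.append(1)                  # node-local ring used for sharding
    return rings, shard_inside_node

# ring_layout(16, 8) -> ([0, 1], True):  shard the optimizer state per node
# ring_layout(8)     -> ([0], False):    previous behaviour, shard over all ranks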
@@ -351,7 +433,8 @@ class DistributedFusedLamb(Optimizer):
                 'max_global_grad_norm': self._max_global_grad_norm,
                 'clip_after_allreduce': self._clip_after_allreduce,
                 'rank': rank,
-                'ring_id': self._ring_id,
+                'nranks': nranks,
+                'ring_id': ring_ids,
                 'use_master_param_norm': self._use_master_param_norm,
                 'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                 'acc_steps': self._gradient_accumulation_steps,
--
GitLab