diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
index f7b8dacfc5aa5f891ba8d4c2d9194f02a60f7fea..77d8682e33562e34afa09df0ce376439c62b3311 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
@@ -151,8 +151,8 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
              "Whether the input gradient has been scaled by nranks.")
         .SetDefault(true);
     AddAttr<int>("nranks", "The world size.").SetDefault(1);
-    AddAttr<std::vector<int>>("ring_id",
-                              "The ring id of the NCCL communicator.")
+    AddAttr<std::vector<int>>("ring_ids",
+                              "The ring ids of the NCCL communicator.")
         .SetDefault({0});
     AddAttr<bool>("use_hierarchical_allreduce",
                   "Whether to use hierarchical allreduce")
diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
index 908be3cd41d21fbd0f73a8218242fe988426f5bf..0289e70d1da567325aee7bf98881d2151b3bfe99 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
@@ -1644,7 +1644,7 @@ class DistributedFusedLambOpKernel
             "The nranks must be exactly divided by num_devices."));
     bool local_shard = (nranks > num_devices);
 
-    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
+    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_ids");
     auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
     auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
     auto use_hierarchical_allreduce =
diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
index 3964f431ac5540c7a576b0ef58f13488886570c7..9a76db3be3d9256df5de98d0bc199966cfcc65ce 100644
--- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py
+++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
@@ -479,7 +479,7 @@ class DistributedFusedLamb(Optimizer):
             'clip_after_allreduce': self._clip_after_allreduce,
             'rank': rank,
             'nranks': nranks,
-            'ring_id': ring_ids,
+            'ring_ids': ring_ids,
             'use_master_param_norm': self._use_master_param_norm,
             'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
             'acc_steps': self._gradient_accumulation_steps,
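
Context for reviewers, not part of the patch: the rename from ring_id to ring_ids reflects that the op attribute carries a list of NCCL communicator ring ids (the default stays {0}), which is what hierarchical allreduce needs. The Python sketch below only illustrates the shape of the data the renamed attribute receives; the build_ring_ids helper and the concrete ids in it are hypothetical, not code from this repository.

# Illustrative sketch only; build_ring_ids and the example ids are hypothetical.
def build_ring_ids(use_hierarchical_allreduce):
    """Return the list of NCCL ring ids to pass as the 'ring_ids' attribute."""
    if not use_hierarchical_allreduce:
        # Single-ring behaviour, matching the attribute's default of {0}.
        return [0]
    # With hierarchical allreduce more than one communicator ring is involved,
    # e.g. a global ring plus additional intra-/inter-node rings.
    return [0, 1, 2]

attrs = {
    'nranks': 16,
    'use_hierarchical_allreduce': True,
    # Renamed attribute: a list of ints instead of a single ring id.
    'ring_ids': build_ring_ids(use_hierarchical_allreduce=True),
}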