Unverified commit a348a423, authored by TaoTao Li, committed by GitHub

rename distributed_fused_lamb attr ring_id->ring_ids (#51000)

Parent 2f900965
@@ -151,8 +151,8 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
                   "Whether the input gradient has been scaled by nranks.")
         .SetDefault(true);
     AddAttr<int64_t>("nranks", "The world size.").SetDefault(1);
-    AddAttr<std::vector<int>>("ring_id",
-                              "The ring id of the NCCL communicator.")
+    AddAttr<std::vector<int>>("ring_ids",
+                              "The ring ids of the NCCL communicator.")
         .SetDefault({0});
     AddAttr<bool>("use_hierarchical_allreduce",
                   "Whether to use hierarchical allreduce")
...
@@ -1644,7 +1644,7 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
             "The nranks must be exactly divided by num_devices."));
     bool local_shard = (nranks > num_devices);
-    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
+    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_ids");
     auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
     auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
     auto use_hierarchical_allreduce =
...
@@ -479,7 +479,7 @@ class DistributedFusedLamb(Optimizer):
                 'clip_after_allreduce': self._clip_after_allreduce,
                 'rank': rank,
                 'nranks': nranks,
-                'ring_id': ring_ids,
+                'ring_ids': ring_ids,
                 'use_master_param_norm': self._use_master_param_norm,
                 'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                 'acc_steps': self._gradient_accumulation_steps,
...
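After this rename, any code that builds the operator's attribute dict by hand must use the key `ring_ids` and pass a list (the op-maker default is `{0}`, i.e. a single-element list) rather than the old single-valued `ring_id`. Below is a minimal migration sketch with hypothetical values; only the attribute keys visible in this diff are used, and none of it is taken verbatim from the patched code.

```python
# Hypothetical migration sketch: keys mirror those shown in this diff,
# values are illustrative only.
old_ring_id = 0                      # old-style single ring id

attrs = {
    'nranks': 16,                    # hypothetical world size
    'ring_ids': [old_ring_id],       # renamed attribute; a std::vector<int> on the C++ side
    'use_hierarchical_allreduce': False,
    'is_grad_scaled_by_nranks': True,
}
```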