Unverified commit a348a423, authored by TaoTao Li, committed by GitHub

rename distributed_fused_lamb attr ring_id->ring_ids (#51000)

Parent 2f900965
@@ -151,8 +151,8 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
               "Whether the input gradient has been scaled by nranks.")
         .SetDefault(true);
     AddAttr<int64_t>("nranks", "The world size.").SetDefault(1);
-    AddAttr<std::vector<int>>("ring_id",
-                              "The ring id of the NCCL communicator.")
+    AddAttr<std::vector<int>>("ring_ids",
+                              "The ring ids of the NCCL communicator.")
         .SetDefault({0});
     AddAttr<bool>("use_hierarchical_allreduce",
                   "Whether to use hierarchical allreduce")
@@ -1644,7 +1644,7 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
             "The nranks must be exactly divided by num_devices."));
     bool local_shard = (nranks > num_devices);
-    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
+    const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_ids");
     auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
     auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
     auto use_hierarchical_allreduce =
@@ -479,7 +479,7 @@ class DistributedFusedLamb(Optimizer):
                 'clip_after_allreduce': self._clip_after_allreduce,
                 'rank': rank,
                 'nranks': nranks,
-                'ring_id': ring_ids,
+                'ring_ids': ring_ids,
                 'use_master_param_norm': self._use_master_param_norm,
                 'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                 'acc_steps': self._gradient_accumulation_steps,
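
For callers outside this patch that build the `distributed_fused_lamb` op with an explicit attrs dict, only the attribute key changes; the attribute was already a `std::vector<int>` and keeps its default of `{0}`. A minimal sketch with illustrative values (`rank`, `nranks`, and `ring_ids` here are hypothetical placeholders, not taken from the patch):

```python
# Illustration of the attribute rename only; names and values below are
# hypothetical, not taken from the patch.
rank, nranks = 0, 8      # example process rank and world size
ring_ids = [0]           # list of NCCL ring ids, matching the op's default {0}

attrs = {
    'rank': rank,
    'nranks': nranks,
    # 'ring_id': ring_ids,   # attribute key before this commit
    'ring_ids': ring_ids,    # attribute key after this commit (#51000)
}
```

Any caller still passing the old 'ring_id' key would need the same one-line rename.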