未验证 提交 c770053c 编写于 作者: S sneaxiy 提交者: GitHub

Add use_hierarchical_allreduce for DistributedFusedLAMB (#44821)

* add use_hierarchical_allreduce

* support hierarchical allreduce for more cases
上级 5ad3228c
......@@ -152,6 +152,9 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("ring_id",
"The ring id of the NCCL communicator.")
.SetDefault({0});
AddAttr<bool>("use_hierarchical_allreduce",
"Whether to use hierarchical allreduce")
.SetDefault(false);
AddComment("The DistributedFusedLamb optimizer.");
}
};
......
......@@ -1614,15 +1614,20 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
auto use_hierarchical_allreduce =
ctx.Attr<bool>("use_hierarchical_allreduce");
VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
<< " , clip_after_allreduce = " << clip_after_allreduce
<< " , use_master_param_norm = " << use_master_param_norm
<< " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
<< " , local_shard = " << local_shard;
<< " , local_shard = " << local_shard
<< " , use_hierarchical_allreduce = "
<< use_hierarchical_allreduce;
// Step 6: allreduce + global norm gradient clip
int64_t global_rank = 0, local_rank = 0;
ncclComm_t global_comm = nullptr, local_comm = nullptr;
ncclComm_t global_comm = nullptr, local_comm = nullptr,
external_comm = nullptr;
if (nranks > 1) {
auto *nccl_comm_handle =
platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
......@@ -1634,6 +1639,11 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
local_comm = local_nccl_comm_handle->comm();
local_rank = local_nccl_comm_handle->rank();
if (use_hierarchical_allreduce) {
external_comm = platform::NCCLCommContext::Instance()
.Get(ring_ids[2], place)
->comm();
}
} else {
local_comm = global_comm;
local_rank = global_rank;
......@@ -1686,6 +1696,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
if (clip_after_allreduce) {
// (1) ReduceScater first
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp32_sum_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_sum_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
......@@ -1700,6 +1743,7 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
global_comm,
stream,
dev_ctx);
}
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
......@@ -1794,6 +1838,41 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
}
// (3) Do ReduceScatter with scale
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx,
fp32_scale);
NCCLReduceScatterWithScale(
fp32_sum_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx,
fp16_scale);
NCCLReduceScatterWithScale(
fp16_sum_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
......@@ -1810,6 +1889,7 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
stream,
dev_ctx,
fp16_scale);
}
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
......@@ -1836,6 +1916,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
}
} else {
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp32_sum_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_sum_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
......@@ -1850,6 +1963,7 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
global_comm,
stream,
dev_ctx);
}
fp32_sum_grad += (local_rank * fp32_numel_each_device);
fp16_sum_grad += (local_rank * fp16_numel_each_device);
} else {
......
......@@ -101,6 +101,7 @@ class DistributedFusedLamb(Optimizer):
gradient_accumulation_steps=1,
use_master_acc_grad=True,
nproc_per_node=None,
use_hierarchical_allreduce=False,
name=None):
assert not framework._non_static_mode(
), "DistributedFusedLamb does not support dygraph mode"
......@@ -129,6 +130,7 @@ class DistributedFusedLamb(Optimizer):
self._gradient_accumulation_steps = gradient_accumulation_steps
self._use_master_acc_grad = use_master_acc_grad
self._nproc_per_node = nproc_per_node
self._use_hierarchical_allreduce = use_hierarchical_allreduce
assert self._gradient_accumulation_steps >= 1
self.helper = LayerHelper('distributed_fused_lamb')
......@@ -305,6 +307,7 @@ class DistributedFusedLamb(Optimizer):
list(range(nranks)), 0)
ring_ids.append(ring_id)
use_hierarchical_allreduce = False
if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
local_group_ranks = list(
range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node))
......@@ -312,6 +315,14 @@ class DistributedFusedLamb(Optimizer):
1)
ring_ids.append(ring_id)
if self._use_hierarchical_allreduce and nranks > nproc_per_node:
use_hierarchical_allreduce = True
outer_group_ranks = list(
range(rank % nproc_per_node, nranks, nproc_per_node))
ring_id = init_communicator(startup_block, rank,
outer_group_ranks, ring_ids[-1] + 1)
ring_ids.append(ring_id)
scale = self._get_or_create_scale()
params = [p for p, _ in params_grads]
......@@ -439,5 +450,6 @@ class DistributedFusedLamb(Optimizer):
'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
'acc_steps': self._gradient_accumulation_steps,
'use_master_acc_grad': self._use_master_acc_grad,
'use_hierarchical_allreduce': use_hierarchical_allreduce,
})
return [lamb_op]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册