Unverified · Commit c770053c, authored by sneaxiy, committed by GitHub

Add use_hierarchical_allreduce for DistributedFusedLAMB (#44821)

* add use_hierarchical_allreduce

* support hierarchical allreduce for more cases

Parent: 5ad3228c
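This commit adds a `use_hierarchical_allreduce` switch to the DistributedFusedLamb optimizer. When it is enabled and the job spans more than one node, the gradient summation on the locally sharded path is split into two stages: an allreduce over an "outer" communicator containing one rank per node (`external_comm`, read from `ring_ids[2]`), followed by a reduce-scatter inside each node (`local_comm`, `ring_ids[1]`), instead of a single allreduce over all ranks (`global_comm`, `ring_ids[0]`).

The snippet below is a hypothetical usage sketch, not part of the commit; it assumes the optimizer is exposed as `paddle.incubate.DistributedFusedLamb` and uses illustrative hyperparameters.

    import paddle
    from paddle.incubate import DistributedFusedLamb

    # DistributedFusedLamb does not support dygraph mode (see the assert in
    # the Python diff below), so run under the static graph mode.
    paddle.enable_static()

    opt = DistributedFusedLamb(
        learning_rate=1e-3,
        lamb_weight_decay=0.01,
        nproc_per_node=8,                  # GPUs per node (assumed)
        use_hierarchical_allreduce=True,   # the flag added by this commit
    )

Note that the Python wrapper only builds the extra outer communicator when `nranks > nproc_per_node`; otherwise the optimizer keeps the original single-communicator path.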
@@ -152,6 +152,9 @@ class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<int>>("ring_id",
                               "The ring id of the NCCL communicator.")
         .SetDefault({0});
+    AddAttr<bool>("use_hierarchical_allreduce",
+                  "Whether to use hierarchical allreduce")
+        .SetDefault(false);
     AddComment("The DistributedFusedLamb optimizer.");
   }
 };
...
@@ -1614,15 +1614,20 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
     const auto &ring_ids = ctx.Attr<std::vector<int>>("ring_id");
     auto use_master_param_norm = ctx.Attr<bool>("use_master_param_norm");
     auto is_grad_scaled_by_nranks = ctx.Attr<bool>("is_grad_scaled_by_nranks");
+    auto use_hierarchical_allreduce =
+        ctx.Attr<bool>("use_hierarchical_allreduce");
     VLOG(10) << "max_global_grad_norm = " << max_global_grad_norm
              << " , clip_after_allreduce = " << clip_after_allreduce
              << " , use_master_param_norm = " << use_master_param_norm
              << " , is_grad_scaled_by_nranks = " << is_grad_scaled_by_nranks
-             << " , local_shard = " << local_shard;
+             << " , local_shard = " << local_shard
+             << " , use_hierarchical_allreduce = "
+             << use_hierarchical_allreduce;
 
     // Step 6: allreduce + global norm gradient clip
     int64_t global_rank = 0, local_rank = 0;
-    ncclComm_t global_comm = nullptr, local_comm = nullptr;
+    ncclComm_t global_comm = nullptr, local_comm = nullptr,
+               external_comm = nullptr;
     if (nranks > 1) {
       auto *nccl_comm_handle =
           platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
@@ -1634,6 +1639,11 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
           platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
       local_comm = local_nccl_comm_handle->comm();
       local_rank = local_nccl_comm_handle->rank();
+      if (use_hierarchical_allreduce) {
+        external_comm = platform::NCCLCommContext::Instance()
+                            .Get(ring_ids[2], place)
+                            ->comm();
+      }
     } else {
       local_comm = global_comm;
       local_rank = global_rank;
@@ -1686,20 +1696,54 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
       if (clip_after_allreduce) {
         // (1) ReduceScater first
         if (local_shard) {
-          NCCLAllReduceWithScale(fp32_grad,
-                                 fp32_sum_grad,
-                                 fp32_numel,
-                                 nranks,
-                                 global_comm,
-                                 stream,
-                                 dev_ctx);
-          NCCLAllReduceWithScale(fp16_grad,
-                                 fp16_sum_grad,
-                                 fp16_numel,
-                                 nranks,
-                                 global_comm,
-                                 stream,
-                                 dev_ctx);
+          if (use_hierarchical_allreduce) {
+            NCCLAllReduceWithScale(fp32_grad,
+                                   fp32_sum_grad,
+                                   fp32_numel,
+                                   nranks / num_devices,
+                                   external_comm,
+                                   stream,
+                                   dev_ctx);
+            NCCLReduceScatterWithScale(
+                fp32_sum_grad,
+                fp32_sum_grad + local_rank * fp32_numel_each_device,
+                fp32_numel_each_device,
+                num_devices,
+                local_comm,
+                stream,
+                dev_ctx);
+            NCCLAllReduceWithScale(fp16_grad,
+                                   fp16_sum_grad,
+                                   fp16_numel,
+                                   nranks / num_devices,
+                                   external_comm,
+                                   stream,
+                                   dev_ctx);
+            NCCLReduceScatterWithScale(
+                fp16_sum_grad,
+                fp16_sum_grad + local_rank * fp16_numel_each_device,
+                fp16_numel_each_device,
+                num_devices,
+                local_comm,
+                stream,
+                dev_ctx);
+          } else {
+            NCCLAllReduceWithScale(fp32_grad,
+                                   fp32_sum_grad,
+                                   fp32_numel,
+                                   nranks,
+                                   global_comm,
+                                   stream,
+                                   dev_ctx);
+            NCCLAllReduceWithScale(fp16_grad,
+                                   fp16_sum_grad,
+                                   fp16_numel,
+                                   nranks,
+                                   global_comm,
+                                   stream,
+                                   dev_ctx);
+          }
           fp32_sum_grad += (local_rank * fp32_numel_each_device);
           fp16_sum_grad += (local_rank * fp16_numel_each_device);
         } else {
@@ -1794,22 +1838,58 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
         }
         // (3) Do ReduceScatter with scale
         if (local_shard) {
-          NCCLAllReduceWithScale(fp32_grad,
-                                 fp32_sum_grad,
-                                 fp32_numel,
-                                 nranks,
-                                 global_comm,
-                                 stream,
-                                 dev_ctx,
-                                 fp32_scale);
-          NCCLAllReduceWithScale(fp16_grad,
-                                 fp16_sum_grad,
-                                 fp16_numel,
-                                 nranks,
-                                 global_comm,
-                                 stream,
-                                 dev_ctx,
-                                 fp16_scale);
+          if (use_hierarchical_allreduce) {
+            NCCLAllReduceWithScale(fp32_grad,
+                                   fp32_sum_grad,
+                                   fp32_numel,
+                                   nranks / num_devices,
+                                   external_comm,
+                                   stream,
+                                   dev_ctx,
+                                   fp32_scale);
+            NCCLReduceScatterWithScale(
+                fp32_sum_grad,
+                fp32_sum_grad + local_rank * fp32_numel_each_device,
+                fp32_numel_each_device,
+                num_devices,
+                local_comm,
+                stream,
+                dev_ctx);
+            NCCLAllReduceWithScale(fp16_grad,
+                                   fp16_sum_grad,
+                                   fp16_numel,
+                                   nranks / num_devices,
+                                   external_comm,
+                                   stream,
+                                   dev_ctx,
+                                   fp16_scale);
+            NCCLReduceScatterWithScale(
+                fp16_sum_grad,
+                fp16_sum_grad + local_rank * fp16_numel_each_device,
+                fp16_numel_each_device,
+                num_devices,
+                local_comm,
+                stream,
+                dev_ctx);
+          } else {
+            NCCLAllReduceWithScale(fp32_grad,
+                                   fp32_sum_grad,
+                                   fp32_numel,
+                                   nranks,
+                                   global_comm,
+                                   stream,
+                                   dev_ctx,
+                                   fp32_scale);
+            NCCLAllReduceWithScale(fp16_grad,
+                                   fp16_sum_grad,
+                                   fp16_numel,
+                                   nranks,
+                                   global_comm,
+                                   stream,
+                                   dev_ctx,
+                                   fp16_scale);
+          }
           fp32_sum_grad += (local_rank * fp32_numel_each_device);
           fp16_sum_grad += (local_rank * fp16_numel_each_device);
         } else {
@@ -1836,20 +1916,54 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
        }
      } else {
        if (local_shard) {
-         NCCLAllReduceWithScale(fp32_grad,
-                                fp32_sum_grad,
-                                fp32_numel,
-                                nranks,
-                                global_comm,
-                                stream,
-                                dev_ctx);
-         NCCLAllReduceWithScale(fp16_grad,
-                                fp16_sum_grad,
-                                fp16_numel,
-                                nranks,
-                                global_comm,
-                                stream,
-                                dev_ctx);
+         if (use_hierarchical_allreduce) {
+           NCCLAllReduceWithScale(fp32_grad,
+                                  fp32_sum_grad,
+                                  fp32_numel,
+                                  nranks / num_devices,
+                                  external_comm,
+                                  stream,
+                                  dev_ctx);
+           NCCLReduceScatterWithScale(
+               fp32_sum_grad,
+               fp32_sum_grad + local_rank * fp32_numel_each_device,
+               fp32_numel_each_device,
+               num_devices,
+               local_comm,
+               stream,
+               dev_ctx);
+           NCCLAllReduceWithScale(fp16_grad,
+                                  fp16_sum_grad,
+                                  fp16_numel,
+                                  nranks / num_devices,
+                                  external_comm,
+                                  stream,
+                                  dev_ctx);
+           NCCLReduceScatterWithScale(
+               fp16_sum_grad,
+               fp16_sum_grad + local_rank * fp16_numel_each_device,
+               fp16_numel_each_device,
+               num_devices,
+               local_comm,
+               stream,
+               dev_ctx);
+         } else {
+           NCCLAllReduceWithScale(fp32_grad,
+                                  fp32_sum_grad,
+                                  fp32_numel,
+                                  nranks,
+                                  global_comm,
+                                  stream,
+                                  dev_ctx);
+           NCCLAllReduceWithScale(fp16_grad,
+                                  fp16_sum_grad,
+                                  fp16_numel,
+                                  nranks,
+                                  global_comm,
+                                  stream,
+                                  dev_ctx);
+         }
          fp32_sum_grad += (local_rank * fp32_numel_each_device);
          fp16_sum_grad += (local_rank * fp16_numel_each_device);
        } else {
...
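In every `local_shard` branch above, the hierarchical path replaces one allreduce over all `nranks` ranks with an allreduce over `nranks / num_devices` peers (one per node, via `external_comm`) followed by a reduce-scatter over the node's `num_devices` ranks (via `local_comm`); in both cases each rank ends up holding its `local_rank`-th shard of the globally summed gradient. The NumPy toy below is my own sketch, not part of the commit, checking that the two reductions agree; the node/device layout and buffer sizes are made up.

    import numpy as np

    num_nodes, num_devices = 2, 4            # assumed layout
    nranks = num_nodes * num_devices
    numel_each_device = 3
    numel = numel_each_device * num_devices  # fused gradient length

    rng = np.random.default_rng(0)
    grads = rng.standard_normal((nranks, numel))  # grads[rank] = local grad

    # Reference path: allreduce over global_comm, then each rank uses its
    # local_rank-th shard of the summed gradient.
    global_sum = grads.sum(axis=0)

    for rank in range(nranks):
        local_rank = rank % num_devices

        # (1) Allreduce over external_comm: device d on every node ends up
        #     holding the sum of the gradients of all ranks whose local rank
        #     is d (nranks / num_devices participants).
        inter_node = [grads[d::num_devices].sum(axis=0)
                      for d in range(num_devices)]

        # (2) Reduce-scatter over local_comm: this rank keeps only its own
        #     shard, summed over the node's devices, which completes the
        #     reduction over all nranks ranks.
        shard = slice(local_rank * numel_each_device,
                      (local_rank + 1) * numel_each_device)
        hier_shard = sum(buf[shard] for buf in inter_node)

        np.testing.assert_allclose(hier_shard, global_sum[shard])

    print("hierarchical path matches global allreduce + local shard")

Because the reduction is linear, applying the per-element scale during the first stage (as the `fp32_scale` / `fp16_scale` arguments do in case (3)) gives the same result as scaling a single global allreduce.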
@@ -101,6 +101,7 @@ class DistributedFusedLamb(Optimizer):
                  gradient_accumulation_steps=1,
                  use_master_acc_grad=True,
                  nproc_per_node=None,
+                 use_hierarchical_allreduce=False,
                  name=None):
         assert not framework._non_static_mode(
         ), "DistributedFusedLamb does not support dygraph mode"
@@ -129,6 +130,7 @@ class DistributedFusedLamb(Optimizer):
         self._gradient_accumulation_steps = gradient_accumulation_steps
         self._use_master_acc_grad = use_master_acc_grad
         self._nproc_per_node = nproc_per_node
+        self._use_hierarchical_allreduce = use_hierarchical_allreduce
         assert self._gradient_accumulation_steps >= 1
 
         self.helper = LayerHelper('distributed_fused_lamb')
@@ -305,6 +307,7 @@ class DistributedFusedLamb(Optimizer):
                                         list(range(nranks)), 0)
             ring_ids.append(ring_id)
 
+        use_hierarchical_allreduce = False
         if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
             local_group_ranks = list(
                 range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node))
@@ -312,6 +315,14 @@ class DistributedFusedLamb(Optimizer):
                                         1)
             ring_ids.append(ring_id)
+            if self._use_hierarchical_allreduce and nranks > nproc_per_node:
+                use_hierarchical_allreduce = True
+                outer_group_ranks = list(
+                    range(rank % nproc_per_node, nranks, nproc_per_node))
+                ring_id = init_communicator(startup_block, rank,
+                                            outer_group_ranks, ring_ids[-1] + 1)
+                ring_ids.append(ring_id)
 
         scale = self._get_or_create_scale()
 
         params = [p for p, _ in params_grads]
@@ -439,5 +450,6 @@ class DistributedFusedLamb(Optimizer):
                 'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                 'acc_steps': self._gradient_accumulation_steps,
                 'use_master_acc_grad': self._use_master_acc_grad,
+                'use_hierarchical_allreduce': use_hierarchical_allreduce,
             })
         return [lamb_op]
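For reference, `ring_ids` ends up with up to three entries: `ring_ids[0]` is the global communicator over all ranks, `ring_ids[1]` the intra-node group, and, when hierarchical allreduce is enabled, `ring_ids[2]` the outer group linking the ranks that share the same local rank across nodes; the kernel reads `external_comm` from `ring_ids[2]`. The sketch below only illustrates the group math for an assumed 2-node x 4-GPU job (it also assumes `node_id = rank // nproc_per_node`, which is how the surrounding code groups ranks); the real optimizer builds NCCL communicators for these groups with `init_communicator`.

    nranks, nproc_per_node = 8, 4   # assumed: 2 nodes x 4 GPUs

    for rank in range(nranks):
        node_id = rank // nproc_per_node
        # ring_ids[1]: intra-node group -> local_comm in the kernel
        local_group_ranks = list(
            range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node))
        # ring_ids[2]: one rank per node sharing the same local rank
        # -> external_comm in the kernel (built only when
        #    use_hierarchical_allreduce=True and nranks > nproc_per_node)
        outer_group_ranks = list(
            range(rank % nproc_per_node, nranks, nproc_per_node))
        print(rank, local_group_ranks, outer_group_ranks)

    # e.g. rank 0 -> local [0, 1, 2, 3], outer [0, 4]
    #      rank 5 -> local [4, 5, 6, 7], outer [1, 5]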