未验证 提交 1f9e2742 编写于 作者: S sneaxiy 提交者: GitHub

opt allreduce (#44843)

上级 d3e90680
...@@ -1697,37 +1697,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T> ...@@ -1697,37 +1697,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
// (1) ReduceScatter first // (1) ReduceScatter first
if (local_shard) { if (local_shard) {
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp32_sum_grad, fp32_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, num_devices,
local_comm, local_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp16_sum_grad, fp16_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, num_devices,
local_comm, local_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
} else { } else {
NCCLAllReduceWithScale(fp32_grad, NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad, fp32_sum_grad,
...@@ -1839,38 +1841,40 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T> ...@@ -1839,38 +1841,40 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
// (3) Do ReduceScatter with scale // (3) Do ReduceScatter with scale
if (local_shard) { if (local_shard) {
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx,
fp32_scale);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp32_sum_grad, fp32_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, num_devices,
local_comm, local_comm,
stream, stream,
dev_ctx,
fp32_scale);
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx); dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx,
fp16_scale);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp16_sum_grad, fp16_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, num_devices,
local_comm, local_comm,
stream, stream,
dev_ctx,
fp16_scale);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx); dev_ctx);
} else { } else {
NCCLAllReduceWithScale(fp32_grad, NCCLAllReduceWithScale(fp32_grad,
...@@ -1917,37 +1921,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T> ...@@ -1917,37 +1921,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
} else { } else {
if (local_shard) { if (local_shard) {
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp32_sum_grad, fp32_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device, fp32_numel_each_device,
num_devices, num_devices,
local_comm, local_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale( NCCLReduceScatterWithScale(
fp16_sum_grad, fp16_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device, fp16_numel_each_device,
num_devices, num_devices,
local_comm, local_comm,
stream, stream,
dev_ctx); dev_ctx);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
} else { } else {
NCCLAllReduceWithScale(fp32_grad, NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad, fp32_sum_grad,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册