未验证 提交 1f9e2742 编写于 作者: S sneaxiy 提交者: GitHub

opt allreduce (#44843)

上级 d3e90680
......@@ -1697,37 +1697,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
// (1) ReduceScater first
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp32_sum_grad,
fp32_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_sum_grad,
fp16_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
......@@ -1839,38 +1841,40 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
// (3) Do ReduceScatter with scale
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx,
fp32_scale);
NCCLReduceScatterWithScale(
fp32_sum_grad,
fp32_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
dev_ctx,
fp32_scale);
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx,
fp16_scale);
dev_ctx);
NCCLReduceScatterWithScale(
fp16_sum_grad,
fp16_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx,
fp16_scale);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
......@@ -1917,37 +1921,39 @@ class DistributedFusedLambOpKernel<phi::GPUContext, T>
} else {
if (local_shard) {
if (use_hierarchical_allreduce) {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
fp32_numel,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp32_sum_grad,
fp32_grad,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(fp16_grad,
fp16_sum_grad,
fp16_numel,
NCCLAllReduceWithScale(
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_sum_grad + local_rank * fp32_numel_each_device,
fp32_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_sum_grad,
fp16_grad,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
num_devices,
local_comm,
stream,
dev_ctx);
NCCLAllReduceWithScale(
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_sum_grad + local_rank * fp16_numel_each_device,
fp16_numel_each_device,
nranks / num_devices,
external_comm,
stream,
dev_ctx);
} else {
NCCLAllReduceWithScale(fp32_grad,
fp32_sum_grad,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册