diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index f1b852301d4d9fe6d0dc9e8c9f3f2b23cbaa27b8..53a5fa4706cc03f07f2fc7fb4ad6c26bf7c26d4b 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -1697,37 +1697,39 @@ class DistributedFusedLambOpKernel // (1) ReduceScater first if (local_shard) { if (use_hierarchical_allreduce) { - NCCLAllReduceWithScale(fp32_grad, - fp32_sum_grad, - fp32_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx); NCCLReduceScatterWithScale( - fp32_sum_grad, + fp32_grad, fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_numel_each_device, num_devices, local_comm, stream, dev_ctx); + NCCLAllReduceWithScale( + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_numel_each_device, + nranks / num_devices, + external_comm, + stream, + dev_ctx); - NCCLAllReduceWithScale(fp16_grad, - fp16_sum_grad, - fp16_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx); NCCLReduceScatterWithScale( - fp16_sum_grad, + fp16_grad, fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_numel_each_device, num_devices, local_comm, stream, dev_ctx); + NCCLAllReduceWithScale( + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_numel_each_device, + nranks / num_devices, + external_comm, + stream, + dev_ctx); } else { NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad, @@ -1839,38 +1841,40 @@ class DistributedFusedLambOpKernel // (3) Do ReduceScatter with scale if (local_shard) { if (use_hierarchical_allreduce) { - NCCLAllReduceWithScale(fp32_grad, - fp32_sum_grad, - fp32_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx, - fp32_scale); NCCLReduceScatterWithScale( - fp32_sum_grad, + fp32_grad, fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_numel_each_device, num_devices, local_comm, stream, + dev_ctx, + fp32_scale); + NCCLAllReduceWithScale( + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_numel_each_device, + nranks / num_devices, + external_comm, + stream, dev_ctx); - NCCLAllReduceWithScale(fp16_grad, - fp16_sum_grad, - fp16_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx, - fp16_scale); NCCLReduceScatterWithScale( - fp16_sum_grad, + fp16_grad, fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_numel_each_device, num_devices, local_comm, stream, + dev_ctx, + fp16_scale); + NCCLAllReduceWithScale( + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_numel_each_device, + nranks / num_devices, + external_comm, + stream, dev_ctx); } else { NCCLAllReduceWithScale(fp32_grad, @@ -1917,37 +1921,39 @@ class DistributedFusedLambOpKernel } else { if (local_shard) { if (use_hierarchical_allreduce) { - NCCLAllReduceWithScale(fp32_grad, - fp32_sum_grad, - fp32_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx); NCCLReduceScatterWithScale( - fp32_sum_grad, + fp32_grad, fp32_sum_grad + local_rank * fp32_numel_each_device, fp32_numel_each_device, num_devices, local_comm, stream, dev_ctx); + NCCLAllReduceWithScale( + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_sum_grad + local_rank * fp32_numel_each_device, + fp32_numel_each_device, + nranks / num_devices, + external_comm, + stream, + dev_ctx); - NCCLAllReduceWithScale(fp16_grad, - fp16_sum_grad, - fp16_numel, - nranks / num_devices, - external_comm, - stream, - dev_ctx); NCCLReduceScatterWithScale( - fp16_sum_grad, + fp16_grad, fp16_sum_grad + local_rank * fp16_numel_each_device, fp16_numel_each_device, num_devices, local_comm, stream, dev_ctx); + NCCLAllReduceWithScale( + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_sum_grad + local_rank * fp16_numel_each_device, + fp16_numel_each_device, + nranks / num_devices, + external_comm, + stream, + dev_ctx); } else { NCCLAllReduceWithScale(fp32_grad, fp32_sum_grad,