From b84edd9069d64bbb896ffbc4a851a806552553ce Mon Sep 17 00:00:00 2001
From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com>
Date: Tue, 18 Oct 2022 16:33:02 +0800
Subject: [PATCH] [cherry-pick] Fix perf issues of mp/pp/fuse in eager mode
 (#47071)

* [Dygraph] Fix performance of pp+mp by using send/recv_calc_stream instead of send/recv (#46116)

* [Dygraph] Fix Perf of FusedFeedForward and FusedAttention with AllReduce (#46780)

* update
---
 .../operators/fused/fused_attention_op.cu    |  6 +++--
 .../operators/fused/fused_feedforward_op.cu  |  6 +++--
 .../distributed/fleet/layers/mpu/mp_ops.py   | 21 ++++++++++++++++-
 .../pp_utils/p2p_communication.py            | 23 ++++++++-----------
 4 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu
index 059d94031a..5ff01e0bc1 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cu
+++ b/paddle/fluid/operators/fused/fused_attention_op.cu
@@ -30,7 +30,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/distributed/collective/ProcessGroup.h"
+#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -50,13 +50,15 @@ static void AllReduce(framework::Tensor &tensor,  // NOLINT
 
   if (map->has(ring_id)) {
     paddle::distributed::ProcessGroup *pg = map->get(ring_id);
+    auto pg_nccl = static_cast<paddle::distributed::ProcessGroupNCCL *>(pg);
+
     std::vector<phi::DenseTensor> in_tensor;
     std::vector<phi::DenseTensor> out_tensor;
     in_tensor.push_back(tensor);
     out_tensor.push_back(tensor);
     paddle::distributed::AllreduceOptions opts;
     opts.reduce_op = distributed::ReduceOp::SUM;
-    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+    auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
     task->Wait();
   } else {
     auto dtype = platform::ToNCCLDataType(
diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu
index 33d1e89bf2..95f1380656 100644
--- a/paddle/fluid/operators/fused/fused_feedforward_op.cu
+++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -23,7 +23,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/distributed/collective/ProcessGroup.h"
+#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -43,13 +43,15 @@ static void AllReduce(framework::Tensor& tensor,  // NOLINT
 
   if (map->has(ring_id)) {
     paddle::distributed::ProcessGroup* pg = map->get(ring_id);
+    auto pg_nccl = static_cast<paddle::distributed::ProcessGroupNCCL*>(pg);
+
     std::vector<phi::DenseTensor> in_tensor;
     std::vector<phi::DenseTensor> out_tensor;
     in_tensor.push_back(tensor);
     out_tensor.push_back(tensor);
     paddle::distributed::AllreduceOptions opts;
     opts.reduce_op = distributed::ReduceOp::SUM;
-    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+    auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
     task->Wait();
   } else {
     auto dtype = platform::ToNCCLDataType(
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
index a2f3bde6cf..18e7b66177 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -43,7 +43,26 @@ def _c_identity(tensor, group=None):
         return
     ring_id = 0 if group is None else group.id
 
-    if _non_static_mode():
+    if in_dygraph_mode():
+        from paddle.autograd import PyLayer
+
+        class c_identity_eager(PyLayer):
+
+            @staticmethod
+            def forward(ctx, tensor):
+                return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True,
+                                                'ring_id', group.id,
+                                                'use_model_parallel', True)
+
+            @staticmethod
+            def backward(ctx, dy):
+                op_type = collective._get_reduce_op(ReduceOp.SUM, "_c_identity")
+                group.process_group.allreduce_on_calc_stream(dy, op_type)
+                return dy
+
+        return c_identity_eager.apply(tensor)
+
+    elif _in_legacy_dygraph():
         return _legacy_C_ops.c_identity(tensor, 'use_calc_stream', True,
                                         'ring_id', ring_id,
                                         'use_model_parallel', True)
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
index e2ca6f8d2a..c1cf0527e1 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -173,7 +173,9 @@ def _partial_send_op(tensor, group, use_calc_stream, ring_id, dst, nranks,
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        return group.process_group.send_partial(tensor, dst, nranks, rank_id)
+        comm_op = group.process_group.send_partial_on_calc_stream \
+            if use_calc_stream else group.process_group.send_partial
+        return comm_op(tensor, dst, nranks, rank_id)
 
 
 def send_partial(tensor,
@@ -212,12 +214,9 @@ def _partial_recv_op(tensor, group, use_calc_stream, ring_id, src, nranks,
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        task = group.process_group.recv_partial(tensor, src, nranks, rank_id)
-        if use_calc_stream:
-            task.wait()
-            return None
-        else:
-            return task
+        comm_op = group.process_group.recv_partial_on_calc_stream \
+            if use_calc_stream else group.process_group.recv_partial
+        return comm_op(tensor, src, nranks, rank_id)
 
 
 def recv_partial(tensor,
@@ -255,13 +254,9 @@ def _partial_allgather_op(tensor, group, use_calc_stream, ring_id, nranks,
     elif in_dygraph_mode():
         group = paddle.distributed.collective._get_default_group(
         ) if group is None else group
-        task = group.process_group.all_gather_partial(tensor, tensor, nranks,
-                                                      rank_id)
-        if use_calc_stream:
-            task.wait()
-            return None
-        else:
-            return task
+        comm_op = group.process_group.all_gather_partial_on_calc_stream \
+            if use_calc_stream else group.process_group.all_gather_partial
+        return comm_op(tensor, tensor, nranks, rank_id)
 
 
 def allgather_partial(tensor,
-- 
GitLab
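
Note (not part of the patch): a minimal sketch of the eager-mode pattern the mp_ops.py change relies on, namely an identity forward whose backward all-reduces the gradient directly on the calculation stream via the ProcessGroup's allreduce_on_calc_stream method, so no extra task wait is needed. The helper name mp_identity is hypothetical; it reuses the same allreduce_on_calc_stream call and private collective._get_reduce_op helper that the patched _c_identity uses.

# Minimal sketch, assuming the eager-mode ProcessGroup API used in this patch
# (allreduce_on_calc_stream) and Paddle's PyLayer custom autograd.
import paddle.distributed as dist
from paddle.autograd import PyLayer
from paddle.distributed import collective


def mp_identity(tensor, group):
    """Identity in forward; sum all-reduce of the gradient over `group` in backward."""

    class Identity(PyLayer):

        @staticmethod
        def forward(ctx, x):
            # Forward is a plain identity; nothing to synchronize here.
            return x

        @staticmethod
        def backward(ctx, dy):
            # All-reduce the gradient across the model-parallel group on the
            # calculation stream, so no explicit task.wait() is required.
            op_type = collective._get_reduce_op(dist.ReduceOp.SUM, "mp_identity")
            group.process_group.allreduce_on_calc_stream(dy, op_type)
            return dy

    return Identity.apply(tensor)

The p2p_communication.py changes follow the same idea: each partial send/recv/allgather helper dispatches to the *_on_calc_stream variant when use_calc_stream is True, which removes the explicit task.wait() from the pipeline-parallel critical path.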