From cbf26bb15473d168d1bfbe6823063dd7ffa182bf Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Fri, 2 Sep 2022 11:01:02 +0800 Subject: [PATCH] fix perf bugs of mp in eager dygraph (#45646) --- python/paddle/distributed/collective.py | 26 +++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index e2dae09b48..74e350b4a5 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -1402,9 +1402,9 @@ def _mp_allreduce(tensor, """ if group is not None and not group.is_member(): return - ring_id = 0 if group is None else group.id if in_dygraph_mode(): + group = _get_default_group() if group is None else group assert op == ReduceOp.SUM, "Unknown parameter: {}.".format(op) from paddle.autograd import PyLayer @@ -1412,14 +1412,19 @@ def _mp_allreduce(tensor, class mp_allreduce_eager(PyLayer): @staticmethod - def forward(ctx, tensor, use_calc_stream, ring_id, + def forward(ctx, tensor, group, use_calc_stream, use_model_parallel): - ctx.ring_id = ring_id - return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream', - use_calc_stream, - 'ring_id', ring_id, - "use_model_parallel", - use_model_parallel) + ctx.ring_id = group.id + + if use_calc_stream: + op_type = _get_reduce_op(op, "_mp_allreduce") + group.process_group.allreduce_on_calc_stream( + tensor, op_type) + return tensor + else: + return _legacy_C_ops.c_allreduce_sum_( + tensor, 'use_calc_stream', use_calc_stream, 'ring_id', + ring_id, "use_model_parallel", use_model_parallel) @staticmethod def backward(ctx, dy): @@ -1427,10 +1432,11 @@ def _mp_allreduce(tensor, 'ring_id', ctx.ring_id, 'use_model_parallel', True) - return mp_allreduce_eager.apply(tensor, use_calc_stream, ring_id, + return mp_allreduce_eager.apply(tensor, group, use_calc_stream, use_model_parallel) - elif _in_legacy_dygraph(): + ring_id = 0 if group is None else group.id + if _in_legacy_dygraph(): if op == ReduceOp.SUM: return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream', use_calc_stream, 'ring_id', -- GitLab