From e9e68c365801a2663b0332a9f050d844a81b0dea Mon Sep 17 00:00:00 2001
From: lilong12
Date: Wed, 6 Apr 2022 12:52:52 +0800
Subject: [PATCH] support group with one rank (#41398)

---
 python/paddle/distributed/collective.py    | 58 +++++++++++++------
 python/paddle/distributed/parallel.py      | 30 +++++-----
 .../tests/unittests/process_group_gloo.py  |  2 +-
 3 files changed, 56 insertions(+), 34 deletions(-)

diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index ecd31386a2..a5ea528d13 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -16,7 +16,6 @@ import numpy as np
 import os
 from datetime import timedelta
 from ..fluid.layer_helper import LayerHelper
-import paddle.fluid.framework as framework
 from ..fluid.framework import Variable
 from ..fluid.framework import in_dygraph_mode
 from ..fluid.framework import OpProtoHolder
@@ -144,6 +143,16 @@ _default_store = None  # the default tcp store
 _default_backend = None
 
 
+def _set_default_backend(backend):
+    global _default_backend
+    _default_backend = backend
+
+
+def _set_default_store(store):
+    global _default_store
+    _default_store = store
+
+
 def _get_group_map():
     global _group_map
     if not _group_map:
@@ -159,19 +168,29 @@ def _get_global_group():
 
 
 def _get_group_map_by_name():
     global _group_map_by_name
-    assert _default_group_name in _group_map_by_name, (
-        "Call paddle.distributed.init_parallel_env first "
-        "to initialize the distributed environment.")
     return _group_map_by_name
 
 
 def _get_default_group():
+    global _group_map_by_name
     assert _default_group_name in _group_map_by_name, (
         "Call paddle.distributed.init_parallel_env first "
         "to initialize the distributed environment.")
     return _get_group_map_by_name()[_default_group_name]
 
 
+def _set_group_map(gid, group):
+    global _group_map
+    assert gid not in _group_map
+    _group_map[gid] = group
+
+
+def _set_group_map_by_name(name, group):
+    global _group_map_by_name
+    assert name not in _group_map_by_name
+    _group_map_by_name[name] = group
+
+
 def _new_ring_id():
     return len(_get_group_map()) + max(_get_global_env().nrings, 9)
@@ -208,6 +227,7 @@ def _new_process_group_impl(backend,
                             pg_options,
                             group_id=0):
     pg = None
+    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
     if backend == "gloo":
         pg = core.ProcessGroupGloo(store, rank, world_size, group_id)
     elif backend == "nccl":
@@ -242,7 +262,7 @@ def barrier(group=None):
     if group is not None and not group.is_member():
         return
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         task = group.process_group.barrier()
         task.wait()
@@ -290,22 +310,22 @@ def new_group(ranks=None, backend=None):
 
     """
     global _group_map
-    if framework._in_eager_mode_:
+    if in_dygraph_mode():
         global _default_group_name
         gid = _new_ring_id()
         group_name = _default_group_name + str(gid)
         global_group = _get_default_group()
         global_rank = global_group.rank
         global_ranks = global_group.ranks
+        backend = _default_backend if backend is None else backend
         if ranks is None:
             ranks = global_ranks
         assert len(ranks) <= len(global_ranks), (
             "Size of new group must be less than or "
             "equal to that of the default global group.")
         size = len(ranks)
-        assert size > 1, "A group must have at least two memebers."
         ranks = sorted(ranks)
-        if global_rank in ranks:
+        if global_rank in ranks and size > 1:
             rank = ranks.index(global_rank)
             pg = _new_process_group_impl(
                 backend,
@@ -495,7 +515,7 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
     if not isinstance(src, int):
         raise ValueError("src should be int.")
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         gsrc = group.get_group_rank(src)
         assert gsrc >= 0, ("src rank out of group, need global rank")
@@ -579,7 +599,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         if op == ReduceOp.SUM:
             op_type = core.ReduceOp.SUM
         elif op == ReduceOp.MAX:
@@ -681,7 +701,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         if op == ReduceOp.SUM:
             op_type = core.ReduceOp.SUM
         elif op == ReduceOp.MAX:
@@ -802,7 +822,7 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         out = paddle.concat(tensor_list)
         task = group.process_group.all_gather(tensor, out)
@@ -899,7 +919,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
     if not isinstance(src, int):
         raise ValueError("src should be int.")
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         gsrc = group.get_group_rank(src)
         rank = group.rank
@@ -916,7 +936,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
         for _ in range(nranks):
             tensor_list.append(tensor)
     temp = paddle.concat(tensor_list, axis=0)
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         task = group.process_group.scatter(temp, tensor, gsrc)
         if use_calc_stream:
             task.wait()
@@ -924,7 +944,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
         else:
             return task
 
-    if in_dygraph_mode():
+    if _non_static_mode():
         return _C_ops.c_scatter(temp, tensor, 'use_calc_stream',
                                 use_calc_stream, 'ring_id', ring_id, 'nranks',
                                 nranks, 'root', gsrc)
@@ -1694,14 +1714,14 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
     else:
         ring_id = 0 if group is None else group.id
 
     temp = paddle.concat(in_tensor_list, axis=0)
     nranks = len(in_tensor_list)
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         out = paddle.concat(out_tensor_list, axis=0)
         task = group.process_group.alltoall(temp, out)
         task.wait()
@@ -1776,7 +1796,7 @@ def send(tensor, dst=0, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         task = group.process_group.send(tensor, dst)
         if use_calc_stream:
@@ -1839,7 +1859,7 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
     if group is not None and not group.is_member():
         return
 
-    if framework._in_eager_mode_ and in_dygraph_mode():
+    if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         task = group.process_group.recv(tensor, src)
         if use_calc_stream:
diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py
index b90f24d377..d9d252024d 100644
--- a/python/paddle/distributed/parallel.py
+++ b/python/paddle/distributed/parallel.py
@@ -24,19 +24,20 @@ from paddle import compat as cpt
 
 # deprecated module import
 from paddle.fluid import core
-import paddle.fluid.framework as framework
+from paddle.fluid.framework import in_dygraph_mode
 from paddle.fluid.framework import _set_expected_place
 from paddle.fluid.dygraph import parallel_helper
 from paddle.distributed.fleet.launch_utils import check_backend
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed.fleet.base.private_helper_function import wait_server_ready  # noqa: F401
-import paddle.distributed.collective as collective
+from paddle.distributed.collective import _set_group_map
+from paddle.distributed.collective import _set_group_map_by_name
+from paddle.distributed.collective import _get_group_map_by_name
 from paddle.distributed.collective import _group_map_by_name
-from paddle.distributed.collective import _group_map
 from paddle.distributed.collective import _default_group_name
 from paddle.distributed.collective import _valid_backend_list
-from paddle.distributed.collective import _default_backend
-from paddle.distributed.collective import _default_store
+from paddle.distributed.collective import _set_default_backend
+from paddle.distributed.collective import _set_default_store
 from paddle.distributed.collective import _new_process_group_impl
 from paddle.distributed.collective import Group
 
@@ -205,10 +206,10 @@ def init_parallel_env():
     _set_expected_place(place)
 
     group = None
-    if backend in _valid_backend_list and framework._in_eager_mode_:
-        if _default_group_name in collective._group_map_by_name:
-            return collective._group_map_by_name[_default_group_name]
-        _default_backend = backend
+    if backend in _valid_backend_list and in_dygraph_mode():
+        if _default_group_name in _get_group_map_by_name():
+            return _get_group_map_by_name()[_default_group_name]
+        _set_default_backend(backend)
         rank = int(os.getenv("PADDLE_TRAINER_ID"))
         world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
         assert rank >= 0 and world_size > rank and world_size > 1, (
@@ -230,11 +231,12 @@ def init_parallel_env():
         master_addr, master_port = endpoints.split(":")
         master_port = int(master_port)
         is_master = rank == 0
-        _default_store = core.TCPStore(master_addr, master_port, is_master,
-                                       world_size)
+        default_store = core.TCPStore(master_addr, master_port, is_master,
+                                      world_size)
+        _set_default_store(default_store)
         pg = _new_process_group_impl(
             backend,
-            _default_store,
+            default_store,
             rank,
             world_size,
             _default_group_name,
@@ -247,8 +249,8 @@ def init_parallel_env():
             ranks=ranks,
             pg=pg,
             name=_default_group_name)
-        collective._group_map_by_name[_default_group_name] = group
-        _group_map[0] = group
+        _set_group_map_by_name(_default_group_name, group)
+        _set_group_map(0, group)
         parallel_helper._set_parallel_ctx(True)
 
         return group
diff --git a/python/paddle/fluid/tests/unittests/process_group_gloo.py b/python/paddle/fluid/tests/unittests/process_group_gloo.py
index b1f3a71ab3..03886ab8a1 100644
--- a/python/paddle/fluid/tests/unittests/process_group_gloo.py
+++ b/python/paddle/fluid/tests/unittests/process_group_gloo.py
@@ -45,7 +45,7 @@ class TestProcessGroupFp32(unittest.TestCase):
         nranks = ParallelEnv().nranks
         rank = ParallelEnv().local_rank
         is_master = True if rank == 0 else False
-        store = paddle.fluid.core.TCPStore("127.0.0.1", 6172, is_master,
+        store = paddle.fluid.core.TCPStore("127.0.0.1", 6272, is_master,
                                            nranks, datetime.timedelta(0))
         pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)
 
-- 
GitLab
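
Note (not part of the patch above): a minimal usage sketch of the behavior this change enables, namely creating a communication group that contains a single rank. It assumes the script is launched with two trainers, for example through python -m paddle.distributed.launch with two processes per node; the ranks passed to new_group below are illustrative only.

    # demo.py -- illustrative sketch only, not shipped with the patch.
    # Launch with at least two trainers, since init_parallel_env() still
    # requires world_size > 1.
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    # Before this patch, new_group() asserted that a group had at least two
    # members. With the patch a single-rank group is accepted, and the
    # underlying process group object is only created when the group has
    # more than one member (the `size > 1` check added above).
    solo_group = dist.new_group(ranks=[0])
    pair_group = dist.new_group(ranks=[0, 1])

    data = paddle.to_tensor([1.0, 2.0, 3.0])
    # Collectives keep working as before on groups with multiple members.
    dist.all_reduce(data, group=pair_group)
    print(solo_group, data.numpy())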