# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime

import paddle

# (TODO: GhostScreaming) It will be removed later.
import paddle.fluid.core as core
from paddle.framework import _non_static_mode, in_dygraph_mode

from .communication.group import Group, _add_new_group, is_initialized
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import split  # noqa: F401

__all__ = []

_global_env = None


def _get_global_env():
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map: the map of all groups, 0 for the global group
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name: the map from group name to group
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group: the map from group to its backend
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl', 'bkcl']
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    global _default_store
    _default_store = store


def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
    global _group_map_by_name
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode
    # rely on the previous syntax.
    if in_dygraph_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
):
    pg = None
    genv = _get_global_env()
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
    elif backend == "nccl":
        pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
    elif backend == "xccl":
        pg = core.ProcessGroupCustom.create(
            store, genv.device_type, rank, world_size, group_id
        )
    elif backend == "bkcl":
        pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)

    return pg


# _custom_gid provides a way for users to
# set the group id, which is usually useful
# to be compatible with the static mode.
_custom_gid = None


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group; only 'nccl' is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store-related
            operations. Default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2, 4, 6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        if size > 1 and global_rank in ranks:
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new method for group management and will
        # replace the previous three maps (_group_map, _group_map_by_name,
        # _group_map_backend) in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)

        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.barrier()
        return group

    if not backend:
        backend = 'nccl'
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                assert False, "no cuda device found"
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if _non_static_mode()
        else paddle.full([0], 1, dtype="int32")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp