#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime

import paddle

# (TODO: GhostScreaming) It will be removed later.
import paddle.fluid.core as core
from paddle.framework import _non_static_mode, in_dygraph_mode

from .communication.group import Group, _add_new_group, is_initialized
from .fleet.layers.mpu.mp_ops import _c_concat  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_identity  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_lookup_table  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy  # noqa: F401
from .fleet.layers.mpu.mp_ops import _c_split  # noqa: F401
from .fleet.layers.mpu.mp_ops import _Linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _mp_allreduce  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_embedding  # noqa: F401
from .fleet.layers.mpu.mp_ops import _parallel_linear  # noqa: F401
from .fleet.layers.mpu.mp_ops import _set_var_distributed  # noqa: F401
from .fleet.layers.mpu.mp_ops import split  # noqa: F401

__all__ = []

_global_env = None


def _get_global_env():
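    # Lazily create and cache a single global ParallelEnv instance.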
    global _global_env
    if not _global_env:
        _global_env = paddle.distributed.ParallelEnv()
    return _global_env


# group map: the map of all groups; key 0 is the global group
# Dict[int, Group]
_group_map = {}
_global_env_gid = 0

# group map by name: the map from group name to group
# Dict[name, Group]
_group_map_by_name = {}

# backend map by group: the map from group to its backend
# Dict[group, backend]
_group_map_backend = {}

# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl', 'bkcl']
_default_store = None  # the default tcp store
_default_backend = None
_default_timeout = datetime.timedelta(seconds=1800)
_start_ring_id = 0


def _set_default_backend(backend):
    global _default_backend
    _default_backend = backend


def _set_default_store(store):
    global _default_store
    _default_store = store


def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
        _group_map[_global_env_gid] = Group(
            genv.rank, 0, list(range(genv.world_size))
        )
    return _group_map


def _get_global_group():
    return _get_group_map()[_global_env_gid]


def _get_group_map_by_name():
    global _group_map_by_name
    return _group_map_by_name


def _get_default_group():
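    # Return the default global group created by init_parallel_env.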
    global _group_map_by_name
    assert is_initialized(), (
        "Call paddle.distributed.init_parallel_env first "
        "to initialize the distributed environment."
    )
    return _get_group_map_by_name()[_default_group_name]


def _set_group_map(gid, group):
    global _group_map
    assert gid not in _group_map
    _group_map[gid] = group


def _set_group_map_by_name(name, group):
    global _group_map_by_name
    assert name not in _group_map_by_name
    _group_map_by_name[name] = group


def _set_group_map_backend(group, backend):
    global _group_map_backend
    assert group not in _group_map_backend
    _group_map_backend[group] = backend


def _new_ring_id():
    # NOTE(liyurui): For compatibility reasons, auto parallel and eager mode rely on the previous syntax.
    if in_dygraph_mode():
        global _start_ring_id
        _start_ring_id += 1
        return _start_ring_id + max(_get_global_env().nrings, 9)
    else:
        return len(_get_group_map()) + max(_get_global_env().nrings, 9)


def _new_process_group_impl(
    backend,
    store,
    rank,
    world_size,
    group_name,
    pg_options,
    group_id=0,
):
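    # Instantiate the low-level ProcessGroup object that matches the requested backend.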
    pg = None
    genv = _get_global_env()
    assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
    if backend == "gloo":
        pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
    elif backend == "nccl":
        pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
    elif backend == "xccl":
        pg = core.ProcessGroupCustom.create(
            store, genv.device_type, rank, world_size, group_id
        )
    elif backend == "bkcl":
        pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)
    return pg


# _custom_gid provides a way for users to set the group id,
# which is usually useful for compatibility with the static graph mode.
_custom_gid = None


def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid


def new_group(ranks=None, backend=None, timeout=_default_timeout):
    """

    Creates a new distributed communication group.

    Args:
        ranks (list): The global ranks of group members.
        backend (str): The backend used to create the group; only nccl is supported now.
        timeout (datetime.timedelta, optional): The waiting timeout for store relevant options. Default is 30 minutes.

    Returns:
        Group: The group instance.

    Examples:
        .. code-block:: python

            import paddle

            paddle.distributed.init_parallel_env()
            tindata = paddle.randn(shape=[2, 3])
            gp = paddle.distributed.new_group([2, 4, 6])
            paddle.distributed.all_reduce(tindata, group=gp, sync_op=False)

    """
    global _custom_gid
    global _group_map
    if in_dygraph_mode():
        global _default_group_name
        gid = _custom_gid if _custom_gid else _new_ring_id()
        group_name = _default_group_name + str(gid)
        if backend != 'heter' and (ranks is None or len(ranks) > 1):
            global_group = _get_default_group()
            global_rank = global_group.rank
            global_ranks = global_group.ranks
            backend = _default_backend if backend is None else backend
            if ranks is None:
                ranks = global_ranks
            assert len(ranks) <= len(global_ranks), (
                "Size of new group must be less than or "
                "equal to that of the default global group."
            )
        size = len(ranks)
        ranks = sorted(ranks)
        if size > 1 and global_rank in ranks:
            rank = 0 if backend == 'heter' else ranks.index(global_rank)
            pg = _new_process_group_impl(
                backend,
                _default_store,
                rank,
                size,
                group_name,
                pg_options=None,
                group_id=gid,
            )
        else:
            rank = -1
            pg = None
        group = Group(rank, gid, ranks, pg=pg, name=group_name)
        _group_map_by_name[group_name] = group
        _group_map[gid] = group
        _group_map_backend[group] = backend
        # TODO: The method below is a new way of group management and will
        # replace the previous three maps in the future.
        _add_new_group(group)

        # TODO(shenliang03): This is a temporary solution to solve the problem of
        # hang caused by tcp
        paddle.distributed.barrier(group=group)
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.barrier()
        return group

    if not backend:
        backend = 'nccl'
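    # Legacy (non-eager) path below: the group is tracked by a ring id and the
    # communication context is initialized through the device-specific ParallelContext.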
    assert backend == 'nccl', "backend other than nccl is not supported yet"

    genv = _get_global_env()
    global_rank = genv.rank

    ring_id = _new_ring_id()

    if global_rank not in ranks:
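        # The current rank is not a member of the new group; only record a
        # placeholder Group with rank -1 and skip communicator initialization.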
        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

        if group_size >= 2:
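            # Describe this sub-group with a ParallelStrategy and initialize the
            # device-specific communication context (NCCL/HCCL/CNCL/BKCL) on the new ring id.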
            strategy = core.ParallelStrategy()
            strategy.nranks = group_size
            strategy.local_rank = group_rank
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_npu():
                place = core.NPUPlace(genv.device_id)
                core.HCCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_mlu():
                place = core.MLUPlace(genv.device_id)
                core.CNCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            elif core.is_compiled_with_xpu():
                place = core.XPUPlace(genv.device_id)
                core.BKCLParallelContext(strategy, place).init_with_ring_id(
                    ring_id
                )
            else:
                assert False, "no cuda device found"
        else:
            return gp

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by cross-creation of new_group
    tmp = (
        paddle.to_tensor([1], dtype="int32")
        if _non_static_mode()
        else paddle.full([0], 1, dtype="int32")
    )
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp