#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from collections import OrderedDict

import paddle
import paddle.fluid.core as core

from ..collective import _get_global_env
from ..collective import _new_ring_id
from ...fluid.framework import _non_static_mode
from ...fluid.layers.tensor import fill_constant
from paddle.fluid.framework import _enable_legacy_dygraph


def get_all_process_groups():
    global _g_process_group_map
    return _g_process_group_map.values()


def get_process_group(group_id, g_process_group_map=None):
    global _g_process_group_map
    if g_process_group_map is None:
        return _g_process_group_map.get(group_id, None)
    return g_process_group_map.get(group_id, None)


def get_world_process_group():
    global _g_process_group_map
    return _g_process_group_map[0]


def clear_all_process_groups():
    global _g_process_group_map
    _g_process_group_map = {}
    _g_process_group_map[0] = ProcessGroup(0, [])


def new_process_group(ranks, group_id=None):
    global _g_process_group_map
    # A key built from the sorted ranks is used to avoid creating duplicate
    # groups; the separator keeps e.g. [1, 23] and [1, 2, 3] distinct.
    new_key = ','.join(map(str, sorted(ranks)))
    for pg_id, pg in _g_process_group_map.items():
        cur_key = ','.join(map(str, sorted(pg.ranks)))
        if pg_id != 0 and new_key == cur_key:
            return pg
    # If no existing group matches, construct a new process group
    num_groups = len(_g_process_group_map)
    # Note: our process group may interfere with the original implementation,
    # so the created group id should start from the original _new_ring_id()
    if group_id is None:
        group_id = _new_ring_id() + num_groups + 1

    new_pg = ProcessGroup(group_id, ranks)
    _g_process_group_map[group_id] = new_pg
    return new_pg
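
# Illustrative sketch (comment only, not executed): creating a group with the
# same set of ranks twice is expected to return the existing group, since the
# lookup key is built from the sorted ranks. The ranks below are hypothetical.
#
#   pg_a = new_process_group([0, 1, 2, 3])
#   pg_b = new_process_group([3, 2, 1, 0])
#   assert pg_a.id == pg_b.id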


# This implementation borrows heavily from Paddle/python/paddle/distributed/collective.py.
# Fleet also has a collective helper that uses ops to initialize communication, in
# Paddle/python/paddle/distributed/fleet/meta_optimizers/common.py. We use the first
# approach because it is simpler. This should be enhanced to manage process membership
# and the instantiation process in a more general way. In the future, the process group
# may also handle the choice of communication implementation.
class ProcessGroup:

    def __init__(self, group_id, ranks):
        if group_id == 0 and get_process_group(0) is not None:
            assert False, "Process group id 0 is reserved for all ranks."
        self._group_id = group_id
        self._ranks = sorted(ranks)
        # Add the current ranks into group 0
        if group_id != 0:
            global _g_process_group_map
            _g_process_group_map[0].add_ranks(ranks)
        self._is_instantiate = False

    @property
    def id(self):
        return self._group_id

    @property
    def ranks(self):
        return self._ranks

    @property
    def nranks(self):
        return len(self._ranks)

    def add_ranks(self, new_ranks):
        if set(new_ranks) <= set(self.ranks):
            return
        assert not self.is_instantiate(), \
            "Cannot add new ranks after instantiating the process group"
        self._ranks.extend(new_ranks)
        self._ranks = sorted(set(self.ranks))

    def local_rank(self, global_rank):
        if global_rank in self.ranks:
            return self.ranks.index(global_rank)
        else:
            assert False, \
                "Rank {} doesn't belong to this group".format(global_rank)

    def is_instantiate(self):
        return self._is_instantiate

    def instantiate(self):
        if self._is_instantiate:
            return
        ring_id = self.id
        genv = _get_global_env()
        global_rank = genv.rank

        if self.nranks >= 2:
            strategy = core.ParallelStrategy()
            strategy.nranks = self.nranks
            strategy.local_rank = self.local_rank(global_rank)
            strategy.trainer_endpoints = [
                genv.trainer_endpoints[i] for i in self.ranks
            ]
            strategy.current_endpoint = genv.current_endpoint
            strategy.nrings = 1

            if core.is_compiled_with_cuda():
                place = core.CUDAPlace(genv.device_id)
                core.NCCLParallelContext(strategy,
                                         place).init_with_ring_id(ring_id)
            else:
                assert False, \
                    "Process groups require Paddle compiled with CUDA (NCCL)"

            # TODO(shenliang03): This is a temporary solution to avoid the hang
            # caused by cross-creation of new_group
            paddle.disable_static()
            _enable_legacy_dygraph()
            paddle.set_device('gpu:%d' %
                              paddle.distributed.ParallelEnv().dev_id)
            tmp = paddle.to_tensor(
                [1], dtype="int32") if _non_static_mode() else fill_constant(
                    [0], dtype="int32", value="1")
            paddle.distributed.all_reduce(tmp, use_calc_stream=True, group=self)
            paddle.distributed.wait(tmp, group=self)
            paddle.enable_static()

        self._is_instantiate = True

    def is_member(self):
        return True

    # def __eq__(self, other):
    #     if not isinstance(other, ProcessGroup):
    #         return False
    #     if self.id != other.id:
    #         return False
    #     return True

    # def __ne__(self, other):
    #     return not self.__eq__(other)

    def __str__(self):
        string = "id: {}, nranks: {}, ranks: {}.".format(
            self.id, self.nranks, ", ".join(map(str, self.ranks)))
        return string


# Note that Process group 0 is reserved for representing all ranks.
# At the beginning, group 0 is empty and new ranks will be added automatically.
_g_process_group_map = OrderedDict()
_g_process_group_map[0] = ProcessGroup(0, [])
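
# Minimal usage sketch (comment only, not executed at import). It assumes a
# distributed launch where the global env (endpoints, ranks) has already been
# set up by Paddle; the ranks below are hypothetical.
#
#   pg = new_process_group([0, 1])   # reuse an existing group if one matches
#   pg.instantiate()                 # build the NCCL ring for this group
#   print(pg)                        # "id: ..., nranks: 2, ranks: 0, 1."
#   local = pg.local_rank(_get_global_env().rank)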