#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import sys
import paddle
import collections
import numpy as np
from itertools import product
from functools import reduce
from ..utils.log_util import logger

__all__ = ['CommunicateTopology', 'HybridCommunicateGroup']

_HYBRID_PARALLEL_GROUP = None


class ParallelMode(object):
    DATA_PARALLEL = 0
    TENSOR_PARALLEL = 1
    PIPELINE_PARALLEL = 2


class CommunicateTopology(object):
    def __init__(self,
                 hybrid_group_names=["data", "pipe", "model"],
                 dims=[1, 1, 1]):
        self._parallel_names = hybrid_group_names
        self._dims = dims
        self.coordinate = collections.namedtuple('Coordinate',
                                                 self._parallel_names)
        self._world_size = reduce(lambda x, y: x * y, self._dims)

        ranges = [range(d) for d in self._dims]
        all_coordinate = [self.coordinate(*x) for x in product(*ranges)]

        self._coord2rank = dict(zip(all_coordinate, range(len(all_coordinate))))
        self._rank2coord = dict(
            zip(self._coord2rank.values(), self._coord2rank.keys()))

    def get_hybrid_group_names(self):
        return self._parallel_names

    def get_dim(self, axis_name):
        return self._dims[self._parallel_names.index(axis_name)]

    def world_size(self):
        return self._world_size

    def get_rank(self, **args):
        assert len(args) == len(self._dims)
        key = self.coordinate(**args)
        assert key in self._coord2rank.keys()
        return self._coord2rank[key]

    def get_coord(self, rank):
        assert rank < self._world_size
        assert rank in self._rank2coord.keys()
        return self._rank2coord[rank]

    def get_axis_list(self, axis_name, index):
        axis = self._parallel_names.index(axis_name)
        ranks = [
            self._coord2rank[coord] for coord in self._coord2rank.keys()
            if coord[axis] == index
        ]
        ranks.sort()
        return ranks

    def get_dim_size(self, axis_name):
        assert axis_name in self._parallel_names
        return self._dims[self._parallel_names.index(axis_name)]

    def get_comm_list(self, axis_name):
        assert axis_name in self._parallel_names
        other_axis_names = [
            name for name in self._parallel_names if name != axis_name
        ]

        ranges = []
        for name in other_axis_names:
            dim_num = self.get_dim_size(name)
            ranges.append(range(dim_num))

        all_result = []
        for x in product(*ranges):
            key_coord = {}
            for other_name in other_axis_names:
                key_coord[other_name] = x[other_axis_names.index(other_name)]

            result = []
            for i in range(0, self.get_dim_size(axis_name)):
                key_coord[axis_name] = i
                result.append(self._coord2rank[self.coordinate(**key_coord)])
            all_result.append(result)

        return all_result
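
    # Illustrative sketch, not used by the module itself: with the axis order
    # ["data", "pipe", "model"] and dims=[2, 2, 2], ranks are assigned in
    # row-major order over the cartesian product of coordinates, i.e.
    # rank = data * 4 + pipe * 2 + model, so one would expect:
    #
    #   topo = CommunicateTopology(["data", "pipe", "model"], [2, 2, 2])
    #   topo.get_rank(data=1, pipe=0, model=1)   # -> 5
    #   topo.get_coord(5)                        # -> Coordinate(data=1, pipe=0, model=1)
    #   topo.get_axis_list("data", 0)            # -> [0, 1, 2, 3]
    #   topo.get_comm_list("pipe")               # -> [[0, 2], [1, 3], [4, 6], [5, 7]]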


class HybridCommunicateGroup(object):
    def __init__(self, topology):
        self.nranks = paddle.distributed.get_world_size()
        self.global_rank = paddle.distributed.get_rank()
        self._topo = topology

        self._dp_degree = self._topo.get_dim('data')
        self._mp_degree = self._topo.get_dim('model')
        self._pp_degree = self._topo.get_dim('pipe')

        self._data_parallel_id = self._get_data_parallel_id()
        self._model_parallel_id = self._get_model_parallel_id()
        self.stage_id = self._get_pipe_parallel_id()

        assert self._check_valid_topo(), \
            "Unreasonable topology setting. world_size: {}, but " \
            "dp_num: {}, mp_num: {}, pp_num: {}".format(
                self.nranks, self._dp_degree, self._mp_degree, self._pp_degree)

        # create comm group for data parallel
        self._dp_group, self._dp_comm_group = self._set_comm_group("data")

        # create comm group for model parallel
        self._mp_group, self._mp_comm_group = self._set_comm_group("model")

        # create comm group for pipe parallel
        self._pp_group, self._pp_comm_group = self._set_comm_group("pipe")

        # create global group for check inf_nan / clip global norm
        self._check_group, self._check_comm_group = self._set_check_group(
            "data")

        # create p2p group
        self.is_first_stage = (self.stage_id == 0)
        self.is_last_stage = (self.stage_id == (self._pp_degree - 1))

        debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \
                    "mp_degree: %d, pp_degree: %d" % (self.global_rank,
                    self._dp_degree, self._mp_degree, self._pp_degree)
        debug_str += ", dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s" % (
            self._dp_group, self._mp_group, self._pp_group, self._check_group)
        logger.info(debug_str)

        # register this instance as the process-wide hybrid parallel group
        global _HYBRID_PARALLEL_GROUP
        _HYBRID_PARALLEL_GROUP = self

    def get_parallel_mode(self):
        # there are three modes: DataParallel / TensorParallel / PipelineParallel
        if self._mp_degree == 1 and self._pp_degree == 1:
            return ParallelMode.DATA_PARALLEL
        elif self._mp_degree > 1 and self._pp_degree == 1:
            # initialize the seed
            return ParallelMode.TENSOR_PARALLEL
        elif self._pp_degree > 1:
            return ParallelMode.PIPELINE_PARALLEL
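
    # Reading of the branches above with concrete degrees (illustrative only):
    # dp=4, mp=1, pp=1 -> ParallelMode.DATA_PARALLEL;
    # dp=2, mp=4, pp=1 -> ParallelMode.TENSOR_PARALLEL;
    # any pp > 1 (e.g. dp=2, mp=2, pp=2) -> ParallelMode.PIPELINE_PARALLEL.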

    def _check_valid_topo(self):
        return self._dp_degree * self._mp_degree * self._pp_degree == self.nranks

    def _set_comm_group(self, parallel_method="data"):
        parallel_group = []
        parallel_comm_group = None
        parallel_groups = self._topo.get_comm_list(parallel_method)

        for group in parallel_groups:
            comm_group = paddle.distributed.new_group(ranks=group)
            if self.global_rank in group:
                parallel_group = group
                parallel_comm_group = comm_group

        assert len(parallel_group) > 0
        assert parallel_comm_group is not None

        return parallel_group, parallel_comm_group
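
    # Note (illustrative): new_group is expected to be invoked collectively, so
    # every rank builds every sub-group in the comm list above and only keeps
    # the one it belongs to. With the hypothetical [2, 2, 2] topology sketched
    # earlier and global_rank == 0:
    #
    #   self._topo.get_comm_list("pipe")  # -> [[0, 2], [1, 3], [4, 6], [5, 7]]
    #   # rank 0 keeps parallel_group == [0, 2] and the matching comm_group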

    def _set_check_group(self, parallel_method="data"):
        parallel_group = []
        parallel_comm_group = None
        parallel_size = self._topo.get_dim(parallel_method)
        for idx in range(parallel_size):
            parallel_groups = self._topo.get_axis_list(parallel_method, idx)
            comm_group = paddle.distributed.new_group(ranks=parallel_groups)
            if self.global_rank in parallel_groups:
                parallel_group = parallel_groups
                parallel_comm_group = comm_group

        assert len(parallel_group) > 0
        assert parallel_comm_group is not None

        return parallel_group, parallel_comm_group
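
    # Sketch for the same hypothetical [2, 2, 2] topology: the check group is
    # built along the "data" axis, so each group gathers every rank that shares
    # one data-parallel index across all model and pipeline coordinates:
    #
    #   self._topo.get_axis_list("data", 0)  # -> [0, 1, 2, 3]
    #   self._topo.get_axis_list("data", 1)  # -> [4, 5, 6, 7]
    #
    # which is what lets the inf/nan check and global-norm clipping span the
    # model and pipeline shards.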

    def topology(self):
        return self._topo

    def get_global_rank(self):
        return self.global_rank

    # data parallel message:
    def _get_data_parallel_id(self):
        return self._topo.get_coord(self.global_rank).data

    def get_data_parallel_rank(self):
        return self._data_parallel_id

    def get_data_parallel_world_size(self):
        return self._dp_degree

    def get_data_parallel_group(self):
        return self._dp_comm_group

    def get_data_parallel_group_src_rank(self):
        return self._dp_comm_group.ranks[0]

    # model parallel message:
    def _get_model_parallel_id(self):
        return self._topo.get_coord(self.global_rank).model

    def get_model_parallel_rank(self):
        return self._model_parallel_id

    def get_model_parallel_world_size(self):
        return self._mp_degree

    def get_model_parallel_group(self):
        return self._mp_comm_group

    def get_model_parallel_group_src_rank(self):
        return self._mp_comm_group.ranks[0]

    # pipeline parallel message
    def _get_pipe_parallel_id(self):
        return self._topo.get_coord(self.global_rank).pipe

    def get_stage_id(self):
        return self.stage_id

    def get_pipe_parallel_world_size(self):
        return self._pp_degree

    def get_pipe_parallel_group(self):
        return self._pp_comm_group

    # check parallel group
    def get_check_parallel_group(self):
        return self._check_comm_group

    def get_rank_from_stage(self, stage_id):
        coord = self._topo.get_coord(self.global_rank)
        tf = coord._replace(pipe=stage_id)._asdict()
        return self._topo.get_rank(**tf)
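
    # Sketch for the same hypothetical [2, 2, 2] topology with axis order
    # ["data", "pipe", "model"]: the current rank keeps its data and model
    # coordinates and only swaps in the requested pipe stage, which is how a
    # pipeline stage locates its p2p peer on another stage. For an instance
    # hcg running on global rank 0, i.e. Coordinate(data=0, pipe=0, model=0):
    #
    #   hcg.get_rank_from_stage(1)  # -> 2, i.e. Coordinate(data=0, pipe=1, model=0)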