orthogonal_strategy.py 7.4 KB
Newer Older
L
LiYuRio 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import collections
import functools
import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase


22
class OrthogonalStrategy:
    """
    A hybrid of multiple distributed strategies. Strategies need to be orthogonal, which means the ranks are organized like
    a square if there are two strategies, a cube if there are three strategies, etc.

    Args:
        list_of_strategy(list): Each strategy in the list should be represented as a tuple, formatted as (strategy_name, degree, strategy_class).
        fused_strategy_dict(dict, optional): Existing strategies can be fused into a new strategy. Use the name of the new strategy as key and a list of
            the strategy names you want to fuse as value. Default: None, which means no fused strategy is created.

    Returns:
        The instance of strategy.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist
            from paddle.distributed.fleet.base.strategy_group import DPGroup, MPGroup, PPGroup
            from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy

            dist.init_parallel_env()
            strategy = OrthogonalStrategy([("dp", 2, DPGroup), ("mp", 2, MPGroup), ("pp", 2, PPGroup)], fused_strategy_dict={"check": ["mp", "pp"]})

    """

    def __init__(self, list_of_strategy, fused_strategy_dict=None):
        self._list_of_strategy = list_of_strategy
        # A fresh dict per instance: a `{}` default in the signature would be
        # one object shared by every instance created without the argument.
        self._fused_strategy_dict = (
            {} if fused_strategy_dict is None else fused_strategy_dict
        )
        self._rank = dist.get_rank()
        self._rank_list_dict = {}
        self._name_to_group_dict = {}
        self._name_to_degree_dict = {}
        self._list_of_strategy_name = [
            strategy[0] for strategy in list_of_strategy
        ]
        self._list_of_degree = [strategy[1] for strategy in list_of_strategy]
        # One named field per strategy, so a coordinate can be built either
        # positionally or by strategy name.
        self._coordinate = collections.namedtuple(
            'Coordinate', self._list_of_strategy_name
        )
        self._check_valid_strategy()

        # Enumerate every coordinate of the rank grid; the position of a
        # coordinate in this enumeration is its global rank (the last
        # strategy in the list varies fastest).
        ranges = [range(degree) for degree in self._list_of_degree]
        list_of_coord = [
            self._coordinate(*coord) for coord in itertools.product(*ranges)
        ]
        self._coord_to_rank_dict = dict(
            zip(list_of_coord, range(len(list_of_coord)))
        )

        for idx, strategy in enumerate(list_of_strategy):
            strategy_name = strategy[0]
            self._name_to_degree_dict[strategy_name] = strategy[1]
            self._rank_list_dict[strategy_name] = self._calc_rank_list(idx)
            # strategy[2] is the group class; it is constructed from the
            # list of rank lists computed for this axis.
            self._name_to_group_dict[strategy_name] = strategy[2](
                self._rank_list_dict[strategy_name]
            )

        self._name_to_fused_group_dict = {}
        self._create_fused_group()

    def strategy_group(self, name):
        """
        Get strategy group with specific name.

        Args:
            name: The name of strategy group

        Returns:
            An instance of specific strategy group.

        Raises:
            AssertionError: If no strategy with this name was defined.
        """
        assert (
            name in self._list_of_strategy_name
        ), "Strategy group {} is not created.".format(name)
        return self._name_to_group_dict[name]

    def fused_strategy_group(self, name):
        """
        Get fused strategy group with specific name.

        Args:
            name: The name of fused strategy group

        Returns:
            (StrategyGroupBase): An instance of strategy group.

        Raises:
            AssertionError: If no fused strategy with this name was created.
        """
        assert (
            name in self._name_to_fused_group_dict
        ), "Fused strategy group {} is not created.".format(name)
        return self._name_to_fused_group_dict[name]

    def rank_in_strategy(self, name):
        """
        Get local rank in strategy group with specific name.

        Args:
            name: The name of strategy group

        Returns:
            (Integer): Local rank in specific strategy.

        Raises:
            AssertionError: If no strategy with this name was defined.
        """
        assert (
            name in self._list_of_strategy_name
        ), "Strategy group {} is not created.".format(name)
        return self._name_to_group_dict[name].group.rank

    def _check_valid_strategy(self):
        """
        Validate the strategy definition.

        Checks that strategy names are unique, that the product of all
        degrees equals the world size, and that every strategy referenced by
        a fused strategy has been defined.

        Raises:
            AssertionError: If any of the checks above fails.
        """
        assert len(self._list_of_strategy_name) == len(
            set(self._list_of_strategy_name)
        ), "Defined duplicated strategies: {}".format(
            self._list_of_strategy_name
        )
        # The initial value 1 keeps an empty strategy list from raising an
        # unrelated TypeError before the world-size check below can report.
        num_of_ranks = functools.reduce(
            lambda x, y: x * y, self._list_of_degree, 1
        )
        world_size = dist.get_world_size()
        assert (
            num_of_ranks == world_size
        ), "There are total {} ranks, but need {} ranks in this strategy.".format(
            world_size, num_of_ranks
        )
        for fused_strategy in self._fused_strategy_dict.values():
            for strategy in fused_strategy:
                assert (
                    strategy in self._list_of_strategy_name
                ), "Can not fuse strategy {} without defined previous.".format(
                    strategy
                )

    def _create_fused_group(self):
        """
        Create one StrategyGroupBase per entry of the fused strategy dict and
        store it under the fused strategy's name.
        """
        for name, fused_strategy in self._fused_strategy_dict.items():
            self._name_to_fused_group_dict[name] = StrategyGroupBase(
                self._calc_fused_rank_list(fused_strategy)
            )

    def _calc_fused_rank_list(self, fused_strategy):
        """
        Compute the rank lists of a fused strategy.

        Args:
            fused_strategy(list): Names of the strategies to fuse.

        Returns:
            (list): A list of rank lists. Each inner list holds the global
            ranks that share the same coordinate on every non-fused strategy.
        """
        fused_names = set(fused_strategy)
        # Keep the non-fused axes in declaration order. Iterating a set of
        # strings here would make the order of the resulting rank lists
        # depend on per-process hash randomization.
        # NOTE(review): presumably every process must build groups in the
        # same order - confirm against StrategyGroupBase.
        non_fused_strategy = [
            strategy
            for strategy in self._list_of_strategy_name
            if strategy not in fused_names
        ]
        non_fused_ranges = [
            range(self._name_to_degree_dict[strategy])
            for strategy in non_fused_strategy
        ]
        fused_ranges = [
            range(self._name_to_degree_dict[strategy])
            for strategy in fused_strategy
        ]

        rank_list = []
        for non_fused_ranks in itertools.product(*non_fused_ranges):
            # Pin the non-fused axes, then sweep every fused coordinate.
            coord_dict = dict(zip(non_fused_strategy, non_fused_ranks))
            ranks = []
            for fused_ranks in itertools.product(*fused_ranges):
                coord_dict.update(zip(fused_strategy, fused_ranks))
                ranks.append(
                    self._coord_to_rank_dict[self._coordinate(**coord_dict)]
                )
            rank_list.append(ranks)
        return rank_list

    def _calc_rank_list(self, strategy_axis):
        """
        Compute the rank lists of the strategy at position ``strategy_axis``.

        Args:
            strategy_axis(int): Index of the strategy in the strategy list.

        Returns:
            (list): A list of rank lists. Each inner list holds the global
            ranks that differ only in their coordinate on ``strategy_axis``.
        """
        other_ranges = [
            range(degree)
            for idx, degree in enumerate(self._list_of_degree)
            if idx != strategy_axis
        ]

        rank_list = []
        for coord in itertools.product(*other_ranges):
            ranks = []
            for val in range(self._list_of_degree[strategy_axis]):
                # Re-insert this strategy's coordinate at its own axis.
                full_coord = (
                    coord[:strategy_axis] + (val,) + coord[strategy_axis:]
                )
                ranks.append(
                    self._coord_to_rank_dict[self._coordinate(*full_coord)]
                )
            rank_list.append(ranks)

        return rank_list