# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import collections
import functools

import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase


class OrthogonalStrategy():
    """
    A hybrid of multiple distributed strategies. Strategies need to be orthogonal,
    meaning the ranks are organized like a square if there are two strategies,
    a cube if there are three strategies, etc.

    Args:
        list_of_strategy(list): Strategies in the list should be represented as tuples,
            formatted as (strategy_name, degree, strategy_class).
        fused_strategy_dict(dict, optional): Existing strategies can be fused into a new
            strategy. Use the name of the new strategy as key, and a list of the strategy
            names you want to fuse as value. Default: None (no fused strategies).

    Returns:
        The instance of strategy.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist
            from paddle.distributed.fleet.base.strategy_group import DPGroup, MPGroup, PPGroup
            from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy

            dist.init_parallel_env()
            strategy = OrthogonalStrategy([("dp", 2, DPGroup), ("mp", 2, MPGroup), ("pp", 2, PPGroup)], fused_strategy_dict={"check": ["mp", "pp"]})

    """

    def __init__(self, list_of_strategy, fused_strategy_dict=None):
        # A fresh dict per instance instead of a shared mutable default.
        if fused_strategy_dict is None:
            fused_strategy_dict = {}
        self._list_of_strategy = list_of_strategy
        self._fused_strategy_dict = fused_strategy_dict
        self._rank = dist.get_rank()
        self._rank_list_dict = {}
        self._name_to_group_dict = {}
        self._name_to_degree_dict = {}
        self._list_of_strategy_name = [
            strategy[0] for strategy in list_of_strategy
        ]
        self._list_of_degree = [strategy[1] for strategy in list_of_strategy]

        # Validate before building the namedtuple: namedtuple itself raises
        # ValueError on duplicate field names, which would otherwise pre-empt
        # the more descriptive duplicate-strategy assertion.
        self._check_valid_strategy()
        # One coordinate field per strategy, in declaration order.
        self._coordinate = collections.namedtuple('Coordinate',
                                                  self._list_of_strategy_name)

        # Enumerate all coordinates with itertools.product (the first strategy is
        # the outermost axis, the last varies fastest) and number them 0, 1, 2, ...
        ranges = [range(degree) for degree in self._list_of_degree]
        list_of_coord = [
            self._coordinate(*coord) for coord in itertools.product(*ranges)
        ]
        self._coord_to_rank_dict = dict(
            zip(list_of_coord, range(len(list_of_coord))))

        for idx, strategy in enumerate(list_of_strategy):
            strategy_name = strategy[0]
            self._name_to_degree_dict[strategy_name] = strategy[1]
            self._rank_list_dict[strategy_name] = self._calc_rank_list(idx)
            # strategy[2] is the group class; it receives the list of rank lists.
            self._name_to_group_dict[strategy_name] = strategy[2](
                self._rank_list_dict[strategy_name])

        self._name_to_fused_group_dict = {}
        self._create_fused_group()

    def strategy_group(self, name):
        """
        Get strategy group with specific name.

        Args:
            name: The name of strategy group

        Returns:
            An instance of specific strategy group.
        """
        assert name in self._list_of_strategy_name, "Strategy group {} is not created.".format(
            name)
        return self._name_to_group_dict[name]

    def fused_strategy_group(self, name):
        """
        Get fused strategy group with specific name.

        Args:
            name: The name of fused strategy group

        Returns:
            (StrategyGroupBase): An instance of strategy group.
        """
        assert name in self._name_to_fused_group_dict, "Fused strategy group {} is not created.".format(
            name)
        return self._name_to_fused_group_dict[name]

    def rank_in_strategy(self, name):
        """
        Get local rank in strategy group with specific name.

        Args:
            name: The name of strategy group

        Returns:
            (Integer): Local rank in specific strategy.
        """
        assert name in self._list_of_strategy_name, "Strategy group {} is not created.".format(
            name)
        return self._name_to_group_dict[name].group.rank

    def _check_valid_strategy(self):
        """
        Validate the strategy configuration.

        Checks that strategy names are unique, that the product of all degrees
        equals the world size, and that every fused strategy only refers to
        declared strategies. Raises AssertionError otherwise.
        """
        assert len(self._list_of_strategy_name) == len(
            set(self._list_of_strategy_name)
        ), "Defined duplicated strategies: {}".format(self._list_of_strategy)
        # Initializer 1 keeps an empty strategy list from raising TypeError.
        num_of_ranks = functools.reduce(lambda x, y: x * y,
                                        self._list_of_degree, 1)
        assert num_of_ranks == dist.get_world_size(
        ), "There are total {} ranks, but need {} ranks in this strategy.".format(
            dist.get_world_size(), num_of_ranks)
        for fused_strategy in self._fused_strategy_dict.values():
            for strategy in fused_strategy:
                assert strategy in self._list_of_strategy_name, "Can not fuse strategy {} without defined previous.".format(
                    strategy)

    def _create_fused_group(self):
        """
        Build one StrategyGroupBase per entry of the fused strategy dict.

        For each fixed coordinate on the non-fused axes, the ranks obtained by
        sweeping all fused axes form one inner rank list.
        """
        for name, fused_strategy in self._fused_strategy_dict.items():
            # Keep the non-fused strategies in declaration order. The previous
            # set-difference iterated in string-hash order, which is randomized
            # per Python process, so different ranks could build rank_list in
            # different orders. NOTE(review): this presumably matters because
            # StrategyGroupBase creates groups in list order — confirm.
            non_fused_strategy = [
                strategy for strategy in self._list_of_strategy_name
                if strategy not in fused_strategy
            ]
            non_fused_ranges = [
                range(self._name_to_degree_dict[strategy])
                for strategy in non_fused_strategy
            ]
            fused_ranges = [
                range(self._name_to_degree_dict[strategy])
                for strategy in fused_strategy
            ]

            rank_list = []
            for non_fused_ranks in itertools.product(*non_fused_ranges):
                coord_dict = dict(zip(non_fused_strategy, non_fused_ranks))
                ranks = []
                for fused_ranks in itertools.product(*fused_ranges):
                    coord_dict.update(zip(fused_strategy, fused_ranks))
                    ranks.append(self._coord_to_rank_dict[self._coordinate(
                        **coord_dict)])
                rank_list.append(ranks)

            self._name_to_fused_group_dict[name] = StrategyGroupBase(rank_list)

    def _calc_rank_list(self, strategy_axis):
        """
        Compute the rank lists of the strategy at position ``strategy_axis``.

        For every coordinate on the other axes, collect the ranks obtained by
        varying only ``strategy_axis`` from 0 to its degree - 1.

        Args:
            strategy_axis(int): Index of the strategy in ``list_of_strategy``.

        Returns:
            (list): A list of rank lists, one per coordinate of the other axes.
        """
        ranges = [
            range(degree) for idx, degree in enumerate(self._list_of_degree)
            if idx != strategy_axis
        ]

        rank_list = []
        for coord in itertools.product(*ranges):
            ranks = []
            for val in range(self._list_of_degree[strategy_axis]):
                # Re-insert the varying axis at its original position.
                coord_list = list(coord)
                coord_list.insert(strategy_axis, val)
                ranks.append(
                    self._coord_to_rank_dict[self._coordinate(*coord_list)])
            rank_list.append(ranks)
        return rank_list