From f48611f38c4b45e832adaa688276f054d0619d8c Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Sun, 14 May 2023 14:12:35 +0800 Subject: [PATCH] [Cherry-Pick]Add identity hcg for hybridparallel (#53787) * add utest * rm hack code --- .../fleet/base/orthogonal_strategy.py | 26 +++++++---- .../distributed/fleet/base/strategy_group.py | 45 +++++++++++++++---- .../collective/orthogonal_strategy.py | 36 +++++++++++++++ 3 files changed, 90 insertions(+), 17 deletions(-) diff --git a/python/paddle/distributed/fleet/base/orthogonal_strategy.py b/python/paddle/distributed/fleet/base/orthogonal_strategy.py index e226b0de2d2..412e679db42 100644 --- a/python/paddle/distributed/fleet/base/orthogonal_strategy.py +++ b/python/paddle/distributed/fleet/base/orthogonal_strategy.py @@ -47,11 +47,16 @@ class OrthogonalStrategy: """ - def __init__(self, list_of_strategy, fused_strategy_dict={}): + def __init__( + self, list_of_strategy, fused_strategy_dict={}, strategy_rank_list=None + ): self._list_of_strategy = list_of_strategy self._fused_strategy_dict = fused_strategy_dict - self._rank = dist.get_rank() - self._rank_list_dict = {} + self._strategy_rank_list = ( + strategy_rank_list + if strategy_rank_list is not None + else list(range(dist.get_world_size())) + ) self._name_to_group_dict = {} self._name_to_degree_dict = {} self._list_of_strategy_name = [ @@ -67,16 +72,17 @@ class OrthogonalStrategy: list_of_coord = [ self._coordinate(*coord) for coord in itertools.product(*ranges) ] + self._coord_to_rank_dict = dict( - zip(list_of_coord, range(len(list_of_coord))) + zip(list_of_coord, self._strategy_rank_list) ) for idx, strategy in enumerate(list_of_strategy): strategy_name = strategy[0] self._name_to_degree_dict[strategy_name] = strategy[1] - self._rank_list_dict[strategy_name] = self._calc_rank_list(idx) + rank_list = self._calc_rank_list(idx) self._name_to_group_dict[strategy_name] = strategy[2]( - self._rank_list_dict[strategy_name] + rank_list, ) self._name_to_fused_group_dict = {} @@ -136,11 +142,13 @@ class OrthogonalStrategy: num_of_ranks = functools.reduce( lambda x, y: x * y, self._list_of_degree ) - assert ( - num_of_ranks == dist.get_world_size() + + assert num_of_ranks == len( + self._strategy_rank_list ), "There are total {} ranks, but need {} ranks in this strategy.".format( - dist.get_world_size(), num_of_ranks + len(self._strategy_rank_list), num_of_ranks ) + for fused_strategy in self._fused_strategy_dict.values(): for strategy in fused_strategy: assert ( diff --git a/python/paddle/distributed/fleet/base/strategy_group.py b/python/paddle/distributed/fleet/base/strategy_group.py index 96c56d80879..9c0159060ec 100644 --- a/python/paddle/distributed/fleet/base/strategy_group.py +++ b/python/paddle/distributed/fleet/base/strategy_group.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. + import paddle.distributed as dist +from paddle.distributed.fleet.layers.mpu import RNGStatesTracker class StrategyGroupBase: @@ -39,6 +41,9 @@ class StrategyGroupBase: """ def __init__(self, list_of_ranks): + """ + Initialize the communication group. + """ assert ( dist.is_initialized() ), "The global communication group need to be initialized." @@ -46,6 +51,19 @@ class StrategyGroupBase: self._rank = dist.get_rank() self._list_of_ranks = list_of_ranks self._group = self._create_group() + self.random_states_tracker = RNGStatesTracker() + + def add_random_seed(self, name, seed): + """ + Add random seed for current rank. + """ + self.random_states_tracker.add(name, seed) + + def get_random_states_tracker(self): + """ + Get the random states tracker. + """ + return self.random_states_tracker @property def world_size(self): @@ -74,17 +92,28 @@ class StrategyGroupBase: return self._group def _create_group(self): - list_of_group = [] + self.list_of_group = [] for ranks in self._list_of_ranks: group = dist.new_group(ranks=ranks) if self._rank in ranks: - list_of_group.append(group) - assert ( - len(list_of_group) > 0 - ), "Rank {} does not belong to the list_of_ranks {}.".format( - self._rank, self._list_of_ranks - ) - return list_of_group if len(list_of_group) > 1 else list_of_group[0] + self.list_of_group.append(group) + + if not self.list_of_group: + return None + else: + return ( + self.list_of_group[0] + if len(self.list_of_group) == 1 + else self.list_of_group + ) + + def __repr__(self): + debug_str = f"seed: {self._seed}; " + if not self.list_of_group: + return debug_str + "No group." + for i in range(len(self.list_of_group)): + debug_str += f"Group[{i}]: {str(self.list_of_group[i])}; " + return debug_str class DPGroup(StrategyGroupBase): diff --git a/python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py b/python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py index 02d92761ea9..90761f9fccd 100644 --- a/python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py +++ b/python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py @@ -14,6 +14,7 @@ import unittest +import paddle import paddle.distributed as dist from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy from paddle.distributed.fleet.base.strategy_group import ( @@ -52,5 +53,40 @@ class TestOrthogonalStrategyAPI(unittest.TestCase): self.assertEqual(fused_group.group.nranks, 1) +class TestOrthogonalStrategyCustomAPI(unittest.TestCase): + def setUp(self): + self._num_of_ranks = 2 + dist.init_parallel_env() + self._global_rank = dist.get_rank() + self._strategy = OrthogonalStrategy( + [ + ("dp", 1, DPGroup), + ("mp", 2, MPGroup), + ("sharding", 1, ShardingGroup), + ("pp", 1, PPGroup), + ], + fused_strategy_dict={"checkness": ["mp", "sharding", "pp"]}, + strategy_rank_list=[1, 0], + ) + + self._strategy.strategy_group("mp").add_random_seed("local_seed", 123) + self._strategy.strategy_group("mp").add_random_seed("global_seed", 321) + + def test_orthogonal_strategy(self): + mp_group = self._strategy.strategy_group("mp") + self.assertEqual(mp_group.world_size, self._num_of_ranks) + self.assertEqual(mp_group.group.nranks, self._num_of_ranks) + self.assertEqual( + self._strategy.rank_in_strategy("mp"), self._global_rank + ) + + fused_group = self._strategy.fused_strategy_group("checkness") + self.assertEqual(fused_group.world_size, 2) + self.assertEqual(fused_group.group.nranks, 2) + + with mp_group.random_states_tracker.rng_state("local_seed"): + a = paddle.randint(0, 100, [10]).numpy()[0] + + if __name__ == '__main__': unittest.main() -- GitLab