From 937e21a3b6aa2a794d4f05b2a65b44317bdbf1e6 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Tue, 27 Jul 2021 10:01:24 +0800
Subject: [PATCH] supports mp and dp hybrid (#34377)

---
 .../distributed/fleet/base/fleet_base.py      | 22 ++++++++++++++++---
 .../unittests/test_fleet_static_mp_layers.py  |  2 ++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index 2a9b15c7325..d0020a2776b 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -269,11 +269,27 @@ class Fleet(object):
         cg.set_comm_group('global', global_rank, global_world_size,
                           global_ring_id, global_ranks)
 
+        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
+        use_mp = use_sharding or use_tensor_parallel
+
         # hybrid group
-        if use_sharding is False: return
+        if use_mp is False: return
+
+        mp_degree_sharding = 1
+        mp_degree_tensor_parallel = 1
+        if use_sharding:
+            sharding_configs = self._user_defined_strategy.sharding_configs
+            mp_degree_sharding = int(sharding_configs['mp_degree'])
+
+        if use_tensor_parallel:
+            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
+            mp_degree_tensor_parallel = int(tensor_parallel_configs[
+                'tensor_parallel_degree'])
+
+        if use_sharding and use_tensor_parallel:
+            assert mp_degree_sharding == mp_degree_tensor_parallel
 
-        sharding_configs = self._user_defined_strategy.sharding_configs
-        mp_degree = int(sharding_configs['mp_degree'])
+        mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel
 
         if mp_degree > 1:
             assert global_world_size % mp_degree == 0
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py b/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py
index 6c7fab25a30..c9de3814f0a 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py
@@ -84,6 +84,8 @@ class TestDistTraning(unittest.TestCase):
             "mp_degree": self.model_parallel_size,
             "sharding_degree": 2,
         }
+        strategy.tensor_parallel = True
+        strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
         fleet.init(is_collective=True, strategy=strategy)
 
     def get_program(self):
-- 
GitLab
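
Usage note (not part of the patch): the sketch below shows how the tensor-parallel switch exercised above might be enabled without sharding under static graph mode. It only uses options the patch itself reads (strategy.tensor_parallel and tensor_parallel_configs['tensor_parallel_degree']) and mirrors the lines added to the unit test; the launch environment (number of trainers, starting the job with fleetrun) is an assumption, not something the patch prescribes.

    # Hypothetical example, assuming a multi-GPU collective job launched with
    # fleetrun so that trainer endpoints and ranks are already set up.
    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()

    strategy = fleet.DistributedStrategy()
    # Enable the tensor-parallel path: with this patch, use_mp is True even
    # when sharding is off, so the model-parallel comm group is still built.
    strategy.tensor_parallel = True
    strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}

    fleet.init(is_collective=True, strategy=strategy)

Per the patch title, the intent is that ranks inside each group of tensor_parallel_degree trainers do model parallelism, while the remaining dimension of the job forms the data-parallel replicas.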