Unverified commit 937e21a3, authored by Yuang Liu, committed by GitHub

supports mp and dp hybrid (#34377)

Parent: 846be131
...@@ -269,11 +269,27 @@ class Fleet(object):
         cg.set_comm_group('global', global_rank, global_world_size,
                           global_ring_id, global_ranks)
 
+        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
+        use_mp = use_sharding or use_tensor_parallel
+
         # hybrid group
-        if use_sharding is False: return
+        if use_mp is False: return
 
-        sharding_configs = self._user_defined_strategy.sharding_configs
-        mp_degree = int(sharding_configs['mp_degree'])
+        mp_degree_sharding = 1
+        mp_degree_tensor_parallel = 1
+        if use_sharding:
+            sharding_configs = self._user_defined_strategy.sharding_configs
+            mp_degree_sharding = int(sharding_configs['mp_degree'])
+
+        if use_tensor_parallel:
+            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
+            mp_degree_tensor_parallel = int(tensor_parallel_configs[
+                'tensor_parallel_degree'])
+
+        if use_sharding and use_tensor_parallel:
+            assert mp_degree_sharding == mp_degree_tensor_parallel
+
+        mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel
 
         if mp_degree > 1:
             assert global_world_size % mp_degree == 0
......
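For context, a minimal usage sketch of the code path this hunk enables: tensor parallelism turned on without sharding, which before this commit returned early and never built the hybrid groups. The strategy fields are taken from the diff itself; the 4-GPU launch setup in the comments is an assumption, not part of this commit.

    # Minimal sketch, assuming the script is launched on 4 GPUs, e.g. via
    #   python -m paddle.distributed.launch --gpus "0,1,2,3" train.py
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    # Previously the hybrid groups were only built when sharding was
    # enabled; use_mp now also covers the tensor_parallel flag.
    strategy.tensor_parallel = True
    strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}

    # With global_world_size == 4 and mp_degree == 2, the new code asserts
    # 4 % 2 == 0 before setting up the mp/dp hybrid communication groups.
    fleet.init(is_collective=True, strategy=strategy)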
...@@ -84,6 +84,8 @@ class TestDistTraning(unittest.TestCase):
             "mp_degree": self.model_parallel_size,
             "sharding_degree": 2,
         }
+        strategy.tensor_parallel = True
+        strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
         fleet.init(is_collective=True, strategy=strategy)
 
     def get_program(self):
......
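Note that when sharding and tensor_parallel are enabled together, as in this test, the new Fleet logic requires sharding_configs['mp_degree'] to equal tensor_parallel_configs['tensor_parallel_degree'] (both 2 here); a mismatch trips the assertion when the hybrid groups are built, rather than silently producing inconsistent communication groups.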