Commit 174e25cf authored by: S sandyhouse

update

Parent 997651ab
@@ -39,7 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             "AMPOptimizer",
             "LarsOptimizer",
             "LambOptimizer",
-            "ModelParallelOptimizer",
+            # "ModelParallelOptimizer",
             "PipelineOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
@@ -358,6 +358,19 @@ class ShardingOptimizer(MetaOptimizerBase):
                                  self._nrings_sharding)
         # config sharding & dp groups
         self._init_comm()
+        # inner & outer model parallelism
+        if self._as_outer_parallelism:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.global_group_endpoints, self.global_rank,
+                self.global_group_id, True)
+        if self._as_outer_parallelism:
+            self._collective_helper._init_communicator(
+                self._startup_program, self.current_endpoint,
+                self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False)
         # sharding
         print("sharding_group_endpoints:", self.sharding_group_endpoints)
         print("sharding_rank:", self.sharding_rank)
@@ -365,13 +378,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._collective_helper._init_communicator(
             self._startup_program, self.current_endpoint,
             self.sharding_group_endpoints, self.sharding_rank,
-            self.sharding_ring_id, True)
-        # inner & outer model parallelism
-        # if self._as_outer_parallelism:
-        #     self._collective_helper._init_communicator(
-        #         self._startup_program, self.current_endpoint,
-        #         self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
+            self.sharding_ring_id, False)
         # dp
         if self.hybrid_dp:
@@ -382,7 +389,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         if self.use_pipeline:
             self._collective_helper._init_communicator(
                 self._startup_program, self.current_endpoint,
-                self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, True)
+                self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()
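Note on the hunks above: when sharding is used as the outer parallelism, _init_comm now builds two extra rings before the existing sharding/dp/pp ones: a global ring spanning every worker (global_group_id, 3 in this commit) and an inner model-parallel ring (mp_group_id, 0), and the trailing boolean passed to _init_communicator (presumably a wait_port flag) is now True only for the first ring that is created. A minimal sketch of the resulting set-up order, not the Paddle API; the dp ring (created when hybrid_dp is on) is unchanged by this diff and omitted:

# Sketch only: which rings the patched _init_comm would create and in what
# order, using the ring ids assigned in this commit (mp=0, sharding=1, pp=2,
# global=3).  The third field mirrors the trailing boolean in the diff.
def planned_rings(as_outer_parallelism, use_pipeline):
    rings = []
    if as_outer_parallelism:
        rings.append(("global", 3, True))   # every worker in one ring
        rings.append(("mp", 0, False))      # inner (megatron) model parallelism
    rings.append(("sharding", 1, False))    # sharding group, flag now False
    if use_pipeline:
        rings.append(("pp", 2, False))      # pipeline ring, flag now False
    return rings


print(planned_rings(as_outer_parallelism=True, use_pipeline=True))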
@@ -482,7 +489,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         # group. and each Data Parallelism group should have its own sync of FoundInfinite
         Model_Paramllelism_ring_id = self.sharding_ring_id
         if self._as_outer_parallelism:
-            Model_Paramllelism_ring_id = self.mp_group_id
+            Model_Paramllelism_ring_id = self.global_group_id
         FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
                              Model_Paramllelism_ring_id)
         gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id)
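The hunk above changes which ring the FoundInfinite check and gradient clipping synchronize over in the outer-parallelism case: previously the inner model-parallel ring (mp_group_id), now the new global ring (global_group_id). A one-line sketch of the selection, with names taken from the diff:

# Sketch of the ring choice made above.
def model_parallelism_ring(sharding_ring_id, global_group_id, as_outer_parallelism):
    # after this commit the cross-group sync uses the global ring
    return global_group_id if as_outer_parallelism else sharding_ring_id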
@@ -826,23 +833,42 @@ class ShardingOptimizer(MetaOptimizerBase):
                 ) == self.sharding_rank
             ]
         else:
+            self.mp_group_id = 0
             self.sharding_ring_id = 1
             self.pp_ring_id = 2
+            self.mp_rank = self.global_rank % self._inner_parallelism_size
+            self.mp_group = self.global_rank // self._inner_parallelism_size
+            self.mp_group_endpoints = [
+                ep for idx, ep in enumerate(self.endpoints)
+                if idx // self._inner_parallelism_size == self.mp_group
+            ]
+            print("megatron_group_endpoints:", self.mp_group_endpoints)
+            print("megatron_rank:", self.mp_rank)
             # self.cards_per_node = 8
             self.sharding_group_size = self.user_defined_strategy.sharding_configs[
                 'sharding_group_size']
-            self.sharding_rank = self.global_rank // self._inner_parallelism_size % self.sharding_group_size
-            # self.sharding_group_id = self.global_rank // (self._inner_parallelism_size % self.sharding_group_size)
+            self.sharding_rank = (
+                self.global_rank //
+                self._inner_parallelism_size) % self.sharding_group_size
+            self.sharding_group_id = self.global_rank // (
+                self._inner_parallelism_size * self.sharding_group_size)
+            self.megatron_rank = self.global_rank % self._inner_parallelism_size
             self.sharding_group_endpoints = [
                 ep for idx, ep in enumerate(self.endpoints)
-                if (idx // self._inner_parallelism_size %
-                    self.sharding_group_size) == self.sharding_rank
+                if (idx // (self._inner_parallelism_size *
+                            self.sharding_group_size)
+                    ) == self.sharding_group_id and idx %
+                self._inner_parallelism_size == self.megatron_rank
             ]
+            print("sharding_endpoint:", self.sharding_group_endpoints)
+            print("sharding_rank:", self.sharding_rank)
             assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
             )
             self.pp_rank = self.global_rank // (
-                self.sharding_group_size * self._inner_parallelism_size)
+                self.sharding_group_size *
+                self._inner_parallelism_size) % self.pipeline_nodes
             offset = self.sharding_group_size * self._inner_parallelism_size
+            # TODO: Adjust for dp
             idx_with_pp_0 = self.global_rank % (
                 self.sharding_group_size * self._inner_parallelism_size)
             self.pp_group_endpoints = []
@@ -850,15 +876,17 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self.pp_group_endpoints.append(self.endpoints[
                     idx_with_pp_0])
                 idx_with_pp_0 += offset
+            print("pp_group_endpoints:", self.pp_group_endpoints)
+            print("pp_rank:", self.pp_rank)
             #self.pp_group_endpoints = [
             #    ep for idx, ep in enumerate(self.endpoints)
             #    if (idx % self.sharding_group_size) == self.sharding_rank
             #]
-            self.mp_group_id = 1
-            self.mp_rank = self.global_rank
-            self.mp_group_size = self.role_maker._worker_num()
-            self.mp_group_endpoints = self.endpoints[:]
+            self.global_group_id = 3
+            self.global_rank = self.global_rank
+            self.global_group_size = self.role_maker._worker_num()
+            self.global_group_endpoints = self.endpoints[:]
             logging.info("Using Sharing as Outer parallelism mode !")
         self.dp_ring_id = -1
         self.dp_rank = -1
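For reference, a self-contained worked example of the index arithmetic the last two hunks introduce. The assumed topology (2-way inner/megatron parallelism, 2-way sharding, 2 pipeline stages, hence 8 workers) and the endpoint addresses are illustrative only; the formulas and names mirror the attributes in the diff.

# Worked example of the rank/group arithmetic added in this commit.
inner_parallelism_size = 2
sharding_group_size = 2
pipeline_nodes = 2
endpoints = [
    "127.0.0.1:%d" % (6000 + i)
    for i in range(inner_parallelism_size * sharding_group_size * pipeline_nodes)
]

for global_rank, _ in enumerate(endpoints):
    # inner model-parallel (megatron) group: blocks of consecutive ranks
    mp_rank = global_rank % inner_parallelism_size
    mp_group = global_rank // inner_parallelism_size
    mp_group_endpoints = [
        ep for idx, ep in enumerate(endpoints)
        if idx // inner_parallelism_size == mp_group
    ]
    # sharding group: ranks with the same megatron rank inside one block of
    # inner_parallelism_size * sharding_group_size workers
    sharding_rank = (global_rank // inner_parallelism_size) % sharding_group_size
    sharding_group_id = global_rank // (inner_parallelism_size * sharding_group_size)
    megatron_rank = global_rank % inner_parallelism_size
    sharding_group_endpoints = [
        ep for idx, ep in enumerate(endpoints)
        if idx // (inner_parallelism_size * sharding_group_size) == sharding_group_id
        and idx % inner_parallelism_size == megatron_rank
    ]
    # pipeline group: one rank per stage, stride = sharding * inner parallelism
    pp_rank = global_rank // (sharding_group_size * inner_parallelism_size) % pipeline_nodes
    offset = sharding_group_size * inner_parallelism_size
    idx_with_pp_0 = global_rank % offset
    pp_group_endpoints = [
        endpoints[idx_with_pp_0 + i * offset] for i in range(pipeline_nodes)
    ]
    print(global_rank, mp_rank, sharding_rank, sharding_group_id, pp_rank,
          sharding_group_endpoints)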