Commit 174e25cf authored by sandyhouse

update

Parent 997651ab
@@ -39,7 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
"ModelParallelOptimizer",
# "ModelParallelOptimizer",
"PipelineOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
@@ -358,6 +358,19 @@ class ShardingOptimizer(MetaOptimizerBase):
self._nrings_sharding)
# config sharding & dp groups
self._init_comm()
# inner & outer model parallelism
if self._as_outer_parallelism:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.global_group_endpoints, self.global_rank,
self.global_group_id, True)
if self._as_outer_parallelism:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False)
# sharding
print("sharding_group_endpoints:", self.sharding_group_endpoints)
print("sharding_rank:", self.sharding_rank)
@@ -365,13 +378,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.sharding_group_endpoints, self.sharding_rank,
self.sharding_ring_id, True)
# inner & outer model parallelism
# if self._as_outer_parallelism:
# self._collective_helper._init_communicator(
# self._startup_program, self.current_endpoint,
# self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
self.sharding_ring_id, False)
# dp
if self.hybrid_dp:
@@ -382,7 +389,7 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.use_pipeline:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, True)
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
@@ -482,7 +489,7 @@ class ShardingOptimizer(MetaOptimizerBase):
# group. and each Data Parallelism group should have its own sync of FoundInfinite
Model_Paramllelism_ring_id = self.sharding_ring_id
if self._as_outer_parallelism:
Model_Paramllelism_ring_id = self.mp_group_id
Model_Paramllelism_ring_id = self.global_group_id
FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
Model_Paramllelism_ring_id)
gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id)
@@ -826,23 +833,42 @@ class ShardingOptimizer(MetaOptimizerBase):
) == self.sharding_rank
]
else:
self.mp_group_id = 0
self.sharding_ring_id = 1
self.pp_ring_id = 2
self.mp_rank = self.global_rank % self._inner_parallelism_size
self.mp_group = self.global_rank // self._inner_parallelism_size
self.mp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if idx // self._inner_parallelism_size == self.mp_group
]
print("megatron_group_endpoints:", self.mp_group_endpoints)
print("megatron_rank:", self.mp_rank)
# self.cards_per_node = 8
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = self.global_rank // self._inner_parallelism_size % self.sharding_group_size
# self.sharding_group_id = self.global_rank // (self._inner_parallelism_size % self.sharding_group_size)
self.sharding_rank = (
self.global_rank //
self._inner_parallelism_size) % self.sharding_group_size
self.sharding_group_id = self.global_rank // (
self._inner_parallelism_size * self.sharding_group_size)
self.megatron_rank = self.global_rank % self._inner_parallelism_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self._inner_parallelism_size %
self.sharding_group_size) == self.sharding_rank
if (idx // (self._inner_parallelism_size *
self.sharding_group_size)
) == self.sharding_group_id and idx %
self._inner_parallelism_size == self.megatron_rank
]
print("sharding_endpoint:", self.sharding_group_endpoints)
print("sharding_rank:", self.sharding_rank)
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
)
self.pp_rank = self.global_rank // (
self.sharding_group_size * self._inner_parallelism_size)
self.sharding_group_size *
self._inner_parallelism_size) % self.pipeline_nodes
offset = self.sharding_group_size * self._inner_parallelism_size
# TODO: Adjust for dp
idx_with_pp_0 = self.global_rank % (
self.sharding_group_size * self._inner_parallelism_size)
self.pp_group_endpoints = []
@@ -850,15 +876,17 @@ class ShardingOptimizer(MetaOptimizerBase):
self.pp_group_endpoints.append(self.endpoints[
idx_with_pp_0])
idx_with_pp_0 += offset
print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank)
#self.pp_group_endpoints = [
# ep for idx, ep in enumerate(self.endpoints)
# if (idx % self.sharding_group_size) == self.sharding_rank
#]
self.mp_group_id = 1
self.mp_rank = self.global_rank
self.mp_group_size = self.role_maker._worker_num()
self.mp_group_endpoints = self.endpoints[:]
self.global_group_id = 3
self.global_rank = self.global_rank
self.global_group_size = self.role_maker._worker_num()
self.global_group_endpoints = self.endpoints[:]
logging.info("Using Sharing as Outer parallelism mode !")
self.dp_ring_id = -1
self.dp_rank = -1
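
For orientation only (not part of the commit), the sketch below reproduces the rank/group arithmetic introduced in the last hunk, assuming a hypothetical 16-card job with an inner (Megatron) parallelism size of 2, a sharding group size of 4, and 2 pipeline stages; the helper name `build_groups` and the endpoint strings are invented for illustration.

```python
def build_groups(global_rank, endpoints, inner_parallelism_size,
                 sharding_group_size, pipeline_nodes):
    # inner model parallelism (Megatron): consecutive ranks form one group
    mp_rank = global_rank % inner_parallelism_size
    mp_group = global_rank // inner_parallelism_size
    mp_group_endpoints = [
        ep for idx, ep in enumerate(endpoints)
        if idx // inner_parallelism_size == mp_group
    ]

    # sharding: peers inside one sharding group sit inner_parallelism_size ranks apart
    sharding_rank = (global_rank //
                     inner_parallelism_size) % sharding_group_size
    sharding_group_id = global_rank // (inner_parallelism_size *
                                        sharding_group_size)
    sharding_group_endpoints = [
        ep for idx, ep in enumerate(endpoints)
        if idx // (inner_parallelism_size * sharding_group_size) ==
        sharding_group_id and idx % inner_parallelism_size == mp_rank
    ]

    # pipeline: one endpoint per stage, `offset` ranks apart
    offset = sharding_group_size * inner_parallelism_size
    pp_rank = global_rank // offset % pipeline_nodes
    idx_with_pp_0 = global_rank % offset
    pp_group_endpoints = []
    for _ in range(pipeline_nodes):
        pp_group_endpoints.append(endpoints[idx_with_pp_0])
        idx_with_pp_0 += offset

    return {
        "mp": (mp_rank, mp_group_endpoints),
        "sharding": (sharding_rank, sharding_group_endpoints),
        "pp": (pp_rank, pp_group_endpoints),
    }


if __name__ == "__main__":
    inner, sharding_size, stages = 2, 4, 2
    eps = ["127.0.0.1:%d" % (6070 + i)
           for i in range(inner * sharding_size * stages)]
    for rank in (0, 5, 13):
        print(rank, build_groups(rank, eps, inner, sharding_size, stages))
```

For global rank 5 this yields mp_rank 1, sharding_rank 2 inside sharding group 0 (the odd-indexed endpoints of the first 8 cards), and pipeline endpoints at indices 5 and 13, which matches the indexing rules added in the hunk above.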