未验证 提交 7be50f90 编写于 作者: L lilong12 提交者: GitHub

update, test=develop (#33588)

上级 172f2719
...@@ -428,59 +428,33 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -428,59 +428,33 @@ class ShardingOptimizer(MetaOptimizerBase):
# pp ring # pp ring
if self.pp_degree > 1: if self.pp_degree > 1:
if self.schedule_mode == 'F-then-B': # GPipe for pair in self.pipeline_pair:
self._collective_helper._init_communicator( pair_key = pair[0] * 1000 + pair[1]
self._startup_program, ring_id = self.pp_ring_map[pair_key]
self.current_endpoint, print("pp pair:{}, ring_id: {}".format(pair, ring_id))
self.pp_group_endpoints, if self.pp_rank not in pair: continue
self.pp_rank, pp_group_endpoints = [
self.pp_ring_id, self.pp_group_endpoints[pair[0]],
False, self.pp_group_endpoints[pair[1]],
global_ring_id=self.global_ring_id, ]
sync=False) if pair[0] < pair[1]:
# append_naive_sync(startup_block, self.startup_prog_sync_var, start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
# self.global_ring_id) else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator( self._collective_helper._init_communicator(
self._startup_program, self._startup_program,
self.current_endpoint, self.current_endpoint,
self.pp_group_endpoints, pp_group_endpoints,
self.pp_rank, pp_rank,
self.pp_ring_id + 2, ring_id,
False, False,
global_ring_id=self.global_ring_id, global_ring_id=self.global_ring_id,
sync=False) sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var, # append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id) # self.global_ring_id)
else:
assert self.schedule_mode == '1F1B' # TODO (JZ-LIANG) to unify this shit
for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
pp_group_endpoints,
pp_rank,
ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
# TODO (JZ-LIANG) to unify this shit
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format( assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank) self.pp_rank_, self.pp_rank)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册