未验证 提交 debae94e 编写于 作者: L lilong12 提交者: GitHub

update, test=develop (#33537)

上级 16099abf
...@@ -429,31 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -429,31 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase):
# pp ring # pp ring
if self.pp_degree > 1: if self.pp_degree > 1:
if self.schedule_mode == 'F-then-B': # GPipe
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id + 2,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
else:
assert self.schedule_mode == '1F1B'
for pair in self.pipeline_pair: for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1] pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key] ring_id = self.pp_ring_map[pair_key]
...@@ -466,8 +441,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -466,8 +441,7 @@ class ShardingOptimizer(MetaOptimizerBase):
if pair[0] < pair[1]: if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1 start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else: else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[ start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1 pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator( self._collective_helper._init_communicator(
self._startup_program, self._startup_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册