Commit a97b9df0 authored by sandyhouse

update

Parent 5294e51c
......
@@ -66,14 +66,20 @@ class CollectiveHelper(object):
             self.role_maker._worker_index(), ring_id, self.wait_port)
         self._broadcast_params()
 
-    def _init_communicator(self, program, current_endpoint, endpoints, rank,
-                           ring_id, wait_port):
+    def _init_communicator(self,
+                           program,
+                           current_endpoint,
+                           endpoints,
+                           rank,
+                           ring_id,
+                           wait_port,
+                           sync=True):
         nranks = len(endpoints)
         other_endpoints = endpoints[:]
         other_endpoints.remove(current_endpoint)
         block = program.global_block()
         if core.is_compiled_with_cuda():
-            if not wait_port:
+            if not wait_port and sync:
                 temp_var = block.create_var(
                     name=unique_name.generate('temp_var'),
                     dtype=core.VarDesc.VarType.INT32,
......
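The hunk above adds a `sync` flag so that callers can skip the synchronization step `_init_communicator` otherwise emits when `wait_port` is off (the `temp_var`-based sync shown above). A minimal, runnable sketch of that gating logic follows; the helper functions are hypothetical stand-ins for the ops the real method appends to the startup program, not Paddle APIs:

```python
def insert_sync_barrier(endpoints):
    # Hypothetical stand-in for the temp_var-based sync emitted above.
    print("barrier across", endpoints)

def build_comm_ring(current, others, ring_id):
    # Hypothetical stand-in for the communicator-init ops.
    print("init ring", ring_id, ":", current, "<->", others)

def init_communicator(current_endpoint, endpoints, ring_id, wait_port,
                      sync=True):
    others = [ep for ep in endpoints if ep != current_endpoint]
    # The change: the barrier now also requires sync=True, so callers
    # (e.g. the pair-wise pipeline rings below) can opt out.
    if not wait_port and sync:
        insert_sync_barrier(others)
    build_comm_ring(current_endpoint, others, ring_id)

# With sync=False, no barrier is emitted before the ring is built.
init_communicator("w0:6000", ["w0:6000", "w1:6000"], 0,
                  wait_port=False, sync=False)
```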
......
@@ -96,6 +96,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             "use_pipeline"]
         self.acc_steps = self.user_defined_strategy.sharding_configs[
             "acc_steps"]
+        self.schedule_mode = self.user_defined_strategy.sharding_configs[
+            "schedule_mode"]
 
         if self.inner_opt is None:
             raise ValueError(
......
@@ -105,6 +107,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             self.acc_steps)
         main_program = loss.block.program
         main_program._pipeline_opt = dict()
+        main_program._pipeline_opt['schedule_mode'] = self.schedule_mode
         pp_rank = self.role_maker._worker_index() // (
             self.user_defined_strategy.sharding_configs[
                 'sharding_group_size'] * self._inner_parallelism_size)
......
@@ -409,20 +412,33 @@ class ShardingOptimizer(MetaOptimizerBase):
             print("pp_group_endpoints:", self.pp_group_endpoints)
             print("pp_rank:", self.pp_rank)
             print("pp_ring_id:", self.pp_ring_id)
-            for pair in self.pipeline_pair:
-                if self.pp_rank not in pair: continue
-                pp_group_endpoints = [
-                    self.pp_group_endpoints[pair[0]],
-                    self.pp_group_endpoints[pair[1]],
-                ]
-                if pair[0] < pair[1]:
-                    start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
-                else:
-                    start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
-                pp_rank = 0 if self.pp_rank == pair[0] else 1
-                self._collective_helper._init_communicator(
-                    self._startup_program, self.current_endpoint,
-                    pp_group_endpoints, pp_rank, start_ring_id, False)
+            if self.schedule_mode == 0:  # GPipe
+                self._collective_helper._init_communicator(
+                    self._startup_program, self.current_endpoint,
+                    self.pp_group_endpoints, self.pp_rank, self.pp_ring_id,
+                    False)
+                self._collective_helper._init_communicator(
+                    self._startup_program, self.current_endpoint,
+                    self.pp_group_endpoints, self.pp_rank, self.pp_ring_id + 2,
+                    False)
+            else:
+                for pair in self.pipeline_pair:
+                    print("pp pair:{}".format(pair))
+                    if self.pp_rank not in pair: continue
+                    pp_group_endpoints = [
+                        self.pp_group_endpoints[pair[0]],
+                        self.pp_group_endpoints[pair[1]],
+                    ]
+                    if pair[0] < pair[1]:
+                        start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
+                    else:
+                        start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
+                            1] - 1
+                    pp_rank = 0 if self.pp_rank == pair[0] else 1
+                    self._collective_helper._init_communicator(
+                        self._startup_program, self.current_endpoint,
+                        pp_group_endpoints, pp_rank, start_ring_id, False,
+                        False)
 
             startup_block = self._startup_program.global_block()
             startup_block._sync_with_cpp()
......
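In the new else branch, each send/recv pair gets its own two-member communicator, and the ring id encodes the pair's direction and stage distance. Below is a self-contained sketch of that arithmetic as read off the formulas above; `PP_RING_ID = 20` is an assumed base id for illustration only:

```python
# Ring-id assignment for pair-wise pipeline communicators: forward
# pairs (src < dst) start at pp_ring_id, backward pairs are offset by
# 2 so the two directions get disjoint ids.
def start_ring_id(pair, pp_ring_id):
    src, dst = pair
    if src < dst:
        return pp_ring_id + (dst - src) - 1   # forward, e.g. 0 -> 1
    return pp_ring_id + 2 + (src - dst) - 1   # backward, e.g. 1 -> 0

PP_RING_ID = 20  # assumed base id, for illustration only
assert start_ring_id((0, 1), PP_RING_ID) == 20  # adjacent forward pair
assert start_ring_id((1, 0), PP_RING_ID) == 22  # adjacent backward pair
assert start_ring_id((0, 2), PP_RING_ID) == 21  # skip-one forward pair
```

Note also that within each pair the local rank is renumbered to 0 or 1 (`pp_rank = 0 if self.pp_rank == pair[0] else 1`), since each of these rings only ever has two members.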
......
@@ -415,6 +415,7 @@ class Section(DeviceWorker):
         section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
         section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
         section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
+        section_param.schedule_mode = pipeline_opt["schedule_mode"]
         cfg = section_param.section_config
         program = pipeline_opt["section_program"]
         cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
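Finally, the Section device worker copies `schedule_mode` from `_pipeline_opt` into the trainer's `section_param`, so the runtime worker can pick the schedule (0 is GPipe per the comment above; the other value presumably selects the pair-wise schedule set up in the else branch). A stubbed sketch of that hand-off; the namespace object stands in for the real trainer protobuf message:

```python
# Stubbed hand-off from _pipeline_opt into the section config.
from types import SimpleNamespace

def gen_section_param(pipeline_opt):
    section_param = SimpleNamespace()
    section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
    section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
    # New field added by this commit:
    section_param.schedule_mode = pipeline_opt["schedule_mode"]
    return section_param

param = gen_section_param({
    "pipeline_stage": 0, "num_pipeline_stages": 2, "schedule_mode": 0})
print(param.schedule_mode)  # 0
```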