Commit f874e02b authored by sandyhouse

update optimizer

Parent d2c81529
@@ -31,6 +31,8 @@ __all__ = ["ShardingOptimizer"]
 class ShardingOptimizer(MetaOptimizerBase):
+    """Sharding Optimizer."""
+
     def __init__(self, optimizer):
         super(ShardingOptimizer, self).__init__(optimizer)
         self.inner_opt = optimizer
@@ -77,6 +79,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                       startup_program=None,
                       parameter_list=None,
                       no_grad_set=None):
+        """Implementation of minimize."""
         # TODO: (JZ-LIANG) support multiple comm in future
         # self._nrings = self.user_defined_strategy.nccl_comm_num
         self._nrings_sharding = 1
@@ -91,12 +94,15 @@ class ShardingOptimizer(MetaOptimizerBase):
             self.user_defined_strategy.sharding_configs["parallelism"])
         self.use_pipeline = self.user_defined_strategy.sharding_configs[
             "use_pipeline"]
+        self.acc_steps = self.user_defined_strategy.sharding_configs[
+            "acc_steps"]
         if self.inner_opt is None:
             raise ValueError(
                 "self.inner_opt of ShardingOptimizer should not be None.")
         if self.use_pipeline:
-            pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
+            pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt,
+                                                             self.acc_steps)
             main_program = loss.block.program
             main_program._pipeline_opt = dict()
             pp_rank = self.role_maker._worker_index() // (
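The new `acc_steps` key is read from `sharding_configs` and forwarded to `PipelineOptimizer` as the number of gradient-accumulation steps per optimizer step. A minimal sketch of the user-side configuration this change assumes; only the keys the diff actually reads (`parallelism`, `use_pipeline`, `acc_steps`) are shown, and all values are illustrative:

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "parallelism": 8,      # number of sharding workers (key read above)
    "use_pipeline": True,  # stack pipeline parallelism on top of sharding
    "acc_steps": 4,        # micro-batches accumulated per optimizer step
}
```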
@@ -107,7 +113,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                 'global_rank'] = self.role_maker._worker_index()
             main_program._pipeline_opt['use_sharding'] = True
             main_program._pipeline_opt['ring_id'] = 2
-            optimize_ops, params_grads, program_list = pp_optimizer.minimize(
+            optimize_ops, params_grads, program_list, self.pipeline_pair = pp_optimizer.minimize(
                 loss, startup_program, parameter_list, no_grad_set)
             self.pipeline_nodes = len(program_list)
         else:
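`pp_optimizer.minimize` now additionally returns `pipeline_pair`, which the communicator setup later in this diff iterates over. Its exact contents come from `PipelineOptimizer`; purely as an illustration, for a hypothetical 4-stage linear pipeline the list could hold one `(src, dst)` stage tuple per point-to-point exchange, with forward and backward directions listed separately:

```python
# Illustrative only: one (src, dst) stage tuple per p2p exchange.
pipeline_pair = [(0, 1), (1, 2), (2, 3),   # forward (activation) sends
                 (3, 2), (2, 1), (1, 0)]   # backward (gradient) sends
```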
@@ -349,8 +355,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         # check op dependency
         check_broadcast(main_block)
-        check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
-                            self.dp_ring_id)
+        #check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
+        #                    self.dp_ring_id)
         #check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
         self._wait()
         return optimize_ops, params_grads
@@ -403,9 +409,20 @@ class ShardingOptimizer(MetaOptimizerBase):
             print("pp_group_endpoints:", self.pp_group_endpoints)
             print("pp_rank:", self.pp_rank)
             print("pp_ring_id:", self.pp_ring_id)
-            self._collective_helper._init_communicator(
-                self._startup_program, self.current_endpoint,
-                self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
+            for pair in self.pipeline_pair:
+                if self.pp_rank not in pair: continue
+                pp_group_endpoints = [
+                    self.pp_group_endpoints[pair[0]],
+                    self.pp_group_endpoints[pair[1]],
+                ]
+                if pair[0] < pair[1]:
+                    start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
+                else:
+                    start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
+                pp_rank = 0 if self.pp_rank == pair[0] else 1
+                self._collective_helper._init_communicator(
+                    self._startup_program, self.current_endpoint,
+                    pp_group_endpoints, pp_rank, start_ring_id, False)
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()
......
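Instead of one communicator spanning all pipeline ranks, each `(src, dst)` pair now gets its own two-rank communicator, and its ring id is derived from the rank distance and direction. A standalone sketch of just that mapping; the base value 20 is hypothetical, the arithmetic mirrors the diff above:

```python
def p2p_ring_id(pp_ring_id, pair):
    # Forward pairs (src < dst) and backward pairs (src > dst) map to
    # disjoint ranges offset from the base pp_ring_id, so each direction
    # of each stage pair gets a distinct ring id.
    src, dst = pair
    if src < dst:
        return pp_ring_id + (dst - src) - 1
    return pp_ring_id + 2 + (src - dst) - 1

# With a hypothetical base ring id of 20:
assert p2p_ring_id(20, (0, 1)) == 20  # adjacent forward pair
assert p2p_ring_id(20, (0, 2)) == 21  # distance-2 forward pair
assert p2p_ring_id(20, (1, 0)) == 22  # adjacent backward pair
```

Within each pair, the lower-endpoint member takes local rank 0 and the other takes rank 1, which is what the `pp_rank = 0 if self.pp_rank == pair[0] else 1` line computes before calling `_init_communicator`.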
@@ -413,6 +413,8 @@ class Section(DeviceWorker):
         section_param = trainer_desc.section_param
         section_param.num_microbatches = pipeline_opt["num_microbatches"]
         section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
+        section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
+        section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
         cfg = section_param.section_config
         program = pipeline_opt["section_program"]
         cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
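On the device-worker side, the `Section` worker now also receives its own stage index and the total stage count through `pipeline_opt`. A hedged sketch of the dict shape implied by the keys read above (the optimizer additionally stores keys such as `'use_sharding'` and `'ring_id'`, as seen in the first file; all values here are illustrative and `main_program` is a stand-in for the per-section program object):

```python
# Hypothetical pipeline_opt holding the keys the Section worker reads.
pipeline_opt = {
    "num_microbatches": 4,       # micro-batches per mini-batch
    "start_cpu_core_id": 0,      # first CPU core pinned to this section
    "pipeline_stage": 1,         # this worker's stage index (new in this diff)
    "num_pipeline_stages": 4,    # total pipeline stages (new in this diff)
    "section_program": {"program": main_program},  # accessed as program["program"]
}
```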