提交 f874e02b 编写于 作者: S sandyhouse

update optimizer

上级 d2c81529
......@@ -31,6 +31,8 @@ __all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase):
"""Sharding Optimizer."""
def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
......@@ -77,6 +79,7 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_program=None,
parameter_list=None,
no_grad_set=None):
"""Implementation of minimize."""
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
......@@ -91,12 +94,15 @@ class ShardingOptimizer(MetaOptimizerBase):
self.user_defined_strategy.sharding_configs["parallelism"])
self.use_pipeline = self.user_defined_strategy.sharding_configs[
"use_pipeline"]
self.acc_steps = self.user_defined_strategy.sharding_configs[
"acc_steps"]
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
if self.use_pipeline:
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt,
self.acc_steps)
main_program = loss.block.program
main_program._pipeline_opt = dict()
pp_rank = self.role_maker._worker_index() // (
......@@ -107,7 +113,7 @@ class ShardingOptimizer(MetaOptimizerBase):
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 2
optimize_ops, params_grads, program_list = pp_optimizer.minimize(
optimize_ops, params_grads, program_list, self.pipeline_pair = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list)
else:
......@@ -349,8 +355,8 @@ class ShardingOptimizer(MetaOptimizerBase):
# check op dependecy
check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
# self.dp_ring_id)
#check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
self._wait()
return optimize_ops, params_grads
......@@ -403,9 +409,20 @@ class ShardingOptimizer(MetaOptimizerBase):
print("pp_group_endpoints:", self.pp_group_endpoints)
print("pp_rank:", self.pp_rank)
print("pp_ring_id:", self.pp_ring_id)
for pair in self.pipeline_pair:
if self.pp_rank not in pair: continue
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, False)
pp_group_endpoints, pp_rank, start_ring_id, False)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
......
......@@ -413,6 +413,8 @@ class Section(DeviceWorker):
section_param = trainer_desc.section_param
section_param.num_microbatches = pipeline_opt["num_microbatches"]
section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
cfg = section_param.section_config
program = pipeline_opt["section_program"]
cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册