Commit 997651ab authored by sandyhouse

update, test=develop

Parent d3105dbf
@@ -99,8 +99,8 @@ void SectionWorker::TrainFiles() {
         VLOG(3) << "Update: running op " << op->Type();
         op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
         if (gc) {
-          DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
-                              gc.get());
+          DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
+                              op.get(), unused_vars_, gc.get());
         }
       }
     }
...
@@ -40,6 +40,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             "LarsOptimizer",
             "LambOptimizer",
             "ModelParallelOptimizer",
+            "PipelineOptimizer",
         ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
         self._main_program = None
@@ -98,14 +99,14 @@ class ShardingOptimizer(MetaOptimizerBase):
             pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
             main_program = loss.block.program
             main_program._pipeline_opt = dict()
-            pp_rank = self.role_maker._worker_index(
-            ) // self.user_defined_strategy.sharding_configs[
-                'sharding_group_size']
+            pp_rank = self.role_maker._worker_index() // (
+                self.user_defined_strategy.sharding_configs[
+                    'sharding_group_size'] * self._inner_parallelism_size)
             main_program._pipeline_opt['local_rank'] = pp_rank
             main_program._pipeline_opt[
                 'global_rank'] = self.role_maker._worker_index()
             main_program._pipeline_opt['use_sharding'] = True
-            main_program._pipeline_opt['ring_id'] = 1
+            main_program._pipeline_opt['ring_id'] = 2
             optimize_ops, params_grads, program_list = pp_optimizer.minimize(
                 loss, startup_program, parameter_list, no_grad_set)
             self.pipeline_nodes = len(program_list)
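For orientation, a minimal standalone recomputation of the revised pp_rank formula in the hunk above. All concrete numbers (worker index, group sizes) are made-up assumptions, not values from this commit:

    worker_index = 13             # hypothetical global rank returned by role_maker._worker_index()
    sharding_group_size = 2       # hypothetical sharding_configs['sharding_group_size']
    inner_parallelism_size = 4    # hypothetical inner (model) parallelism degree

    # Old formula: divide only by the sharding group size.
    old_pp_rank = worker_index // sharding_group_size                             # -> 6

    # New formula: one pipeline stage spans sharding_group_size * inner_parallelism_size ranks.
    new_pp_rank = worker_index // (sharding_group_size * inner_parallelism_size)  # -> 1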
@@ -358,16 +359,19 @@ class ShardingOptimizer(MetaOptimizerBase):
         # config sharding & dp groups
         self._init_comm()
         # sharding
+        print("sharding_group_endpoints:", self.sharding_group_endpoints)
+        print("sharding_rank:", self.sharding_rank)
+        print("sharding_ring_id:", self.sharding_ring_id)
         self._collective_helper._init_communicator(
             self._startup_program, self.current_endpoint,
             self.sharding_group_endpoints, self.sharding_rank,
             self.sharding_ring_id, True)
 
         # inner & outer model parallelism
-        if self._as_outer_parallelism:
-            self._collective_helper._init_communicator(
-                self._startup_program, self.current_endpoint,
-                self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
+        # if self._as_outer_parallelism:
+        #     self._collective_helper._init_communicator(
+        #         self._startup_program, self.current_endpoint,
+        #         self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
 
         # dp
         if self.hybrid_dp:
@@ -757,7 +761,7 @@ class ShardingOptimizer(MetaOptimizerBase):
             logging.info("Using Sharing&DP mode !")
         else:
-            if self._as_outer_parallelism:
+            if self._as_outer_parallelism and not self.use_pipeline:
                 self.sharding_ring_id = 1
                 assert self.global_word_size > self._inner_parallelism_size, \
                     "global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
@@ -801,75 +805,113 @@ class ShardingOptimizer(MetaOptimizerBase):
                 # logging.info("megatron endpoints: {}".format(
                 #     magetron_endpoints))
             if self.use_pipeline:
-                self.sharding_ring_id = 0
-                self.sharding_group_size = self.user_defined_strategy.sharding_configs[
-                    'sharding_group_size']
-                self.sharding_rank = self.global_rank % self.sharding_group_size
-                assert self.sharding_group_size * self.pipeline_nodes == self.role_maker._worker_num(
-                )
-                self.pp_ring_id = 1
-                self.pp_rank = self.global_rank // self.sharding_group_size
-                self.sharding_group_endpoints = [
-                    ep for idx, ep in enumerate(self.endpoints)
-                    if (idx // self.sharding_group_size) == self.pp_rank
-                ]
-                self.pp_group_size = self.pipeline_nodes
-                self.pp_group_endpoints = [
-                    ep for idx, ep in enumerate(self.endpoints)
-                    if (idx % self.sharding_group_size) == self.sharding_rank
-                ]
+                if self._inner_parallelism_size == 1:
+                    self.sharding_ring_id = 0
+                    self.sharding_group_size = self.user_defined_strategy.sharding_configs[
+                        'sharding_group_size']
+                    self.sharding_rank = self.global_rank % self.sharding_group_size
+                    assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
+                    )
+                    self.pp_ring_id = 2
+                    self.pp_rank = self.global_rank // (
+                        self.sharding_group_size * self._inner_parallelism_size)
+                    self.sharding_group_endpoints = [
+                        ep for idx, ep in enumerate(self.endpoints)
+                        if (idx // self.sharding_group_size) == self.pp_rank
+                    ]
+                    self.pp_group_size = self.pipeline_nodes
+                    self.pp_group_endpoints = [
+                        ep for idx, ep in enumerate(self.endpoints)
+                        if (idx % self.sharding_group_size
+                            ) == self.sharding_rank
+                    ]
+                else:
+                    self.sharding_ring_id = 1
+                    self.pp_ring_id = 2
+                    # self.cards_per_node = 8
+                    self.sharding_group_size = self.user_defined_strategy.sharding_configs[
+                        'sharding_group_size']
+                    self.sharding_rank = self.global_rank // self._inner_parallelism_size % self.sharding_group_size
+                    # self.sharding_group_id = self.global_rank // (self._inner_parallelism_size % self.sharding_group_size)
+                    self.sharding_group_endpoints = [
+                        ep for idx, ep in enumerate(self.endpoints)
+                        if (idx // self._inner_parallelism_size %
+                            self.sharding_group_size) == self.sharding_rank
+                    ]
+                    assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
+                    )
+                    self.pp_rank = self.global_rank // (
+                        self.sharding_group_size * self._inner_parallelism_size)
+                    offset = self.sharding_group_size * self._inner_parallelism_size
+                    idx_with_pp_0 = self.global_rank % (
+                        self.sharding_group_size * self._inner_parallelism_size)
+                    self.pp_group_endpoints = []
+                    for i in range(self.pipeline_nodes):
+                        self.pp_group_endpoints.append(self.endpoints[
+                            idx_with_pp_0])
+                        idx_with_pp_0 += offset
+                    #self.pp_group_endpoints = [
+                    #    ep for idx, ep in enumerate(self.endpoints)
+                    #    if (idx % self.sharding_group_size) == self.sharding_rank
+                    #]
+                    self.mp_group_id = 1
+                    self.mp_rank = self.global_rank
+                    self.mp_group_size = self.role_maker._worker_num()
+                    self.mp_group_endpoints = self.endpoints[:]
+                    logging.info("Using Sharing as Outer parallelism mode !")
                 self.dp_ring_id = -1
                 self.dp_rank = -1
                 self.dp_group_size = None
                 self.dp_group_endpoints = None
                 logging.info("Using Sharing with pipeline !")
-            else:
-                self.sharding_ring_id = 0
-                self.sharding_rank = self.global_rank
-                self.sharding_group_size = self.role_maker._worker_num()
-                self.sharding_group_endpoints = self.endpoints
-                # sharding parallelism is the only model parallelism in the current setting
-                self.mp_group_id = self.sharding_ring_id
-                self.mp_rank = self.sharding_rank
-                self.mp_group_size = self.sharding_group_size
-                self.mp_group_endpoints = self.sharding_group_endpoints[:]
-                logging.info("Using Sharing alone mode !")
+            #else:
+            #    self.sharding_ring_id = 0
+            #    self.sharding_rank = self.global_rank
+            #    self.sharding_group_size = self.role_maker._worker_num()
+            #    self.sharding_group_endpoints = self.endpoints
+            #    # sharding parallelism is the only model parallelism in the current setting
+            #    self.mp_group_id = self.sharding_ring_id
+            #    self.mp_rank = self.sharding_rank
+            #    self.mp_group_size = self.sharding_group_size
+            #    self.mp_group_endpoints = self.sharding_group_endpoints[:]
+            #    logging.info("Using Sharing alone mode !")
             self.dp_ring_id = -1
             self.dp_rank = -1
             self.dp_group_size = None
             self.dp_group_endpoints = None
-            self.pp_ring_id = -1
-            self.pp_rank = -1
-            self.pp_group_size = None
-            self.pp_group_endpoints = None
-            self.dp_ring_id = -1
-            self.dp_rank = -1
-            self.dp_group_size = None
-            self.dp_group_endpoints = None
+            #self.pp_ring_id = -1
+            #self.pp_rank = -1
+            #self.pp_group_size = None
+            #self.pp_group_endpoints = None
+            #self.dp_ring_id = -1
+            #self.dp_rank = -1
+            #self.dp_group_size = None
+            #self.dp_group_endpoints = None
             logging.info("Using Sharing alone mode !")
-        logging.info("global word size: {}".format(self.global_word_size))
-        logging.info("global rank: {}".format(self.global_rank))
-        logging.info("sharding group_size: {}".format(self.sharding_group_size))
-        logging.info("sharding rank: {}".format(self.sharding_rank))
-        logging.info("current model parallelism group_size: {}".format(
-            self.mp_group_size))
-        logging.info("current model parallelism rank: {}".format(self.mp_rank))
-        logging.info("dp group size: {}".format(self.dp_group_size))
-        logging.info("dp rank: {}".format(self.dp_rank))
-        logging.info("current endpoint: {}".format(self.current_endpoint))
-        logging.info("global word endpoints: {}".format(self.endpoints))
-        logging.info("sharding group endpoints: {}".format(
-            self.sharding_group_endpoints))
-        logging.info("current model parallelism group endpoints: {}".format(
-            self.mp_group_endpoints))
-        logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
+        #logging.info("global word size: {}".format(self.global_word_size))
+        #logging.info("global rank: {}".format(self.global_rank))
+        #logging.info("sharding group_size: {}".format(self.sharding_group_size))
+        #logging.info("sharding rank: {}".format(self.sharding_rank))
+        #logging.info("current model parallelism group_size: {}".format(
+        #    self.mp_group_size))
+        #logging.info("current model parallelism rank: {}".format(self.mp_rank))
+        #logging.info("dp group size: {}".format(self.dp_group_size))
+        #logging.info("dp rank: {}".format(self.dp_rank))
+        #logging.info("current endpoint: {}".format(self.current_endpoint))
+        #logging.info("global word endpoints: {}".format(self.endpoints))
+        #logging.info("sharding group endpoints: {}".format(
+        #    self.sharding_group_endpoints))
+        #logging.info("current model parallelism group endpoints: {}".format(
+        #    self.mp_group_endpoints))
+        #logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
         return
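For orientation, a minimal standalone sketch of the rank-to-group arithmetic used in the new `else` branch above. The cluster shape and endpoint strings are assumptions chosen for illustration (16 workers, inner_parallelism_size=2, sharding_group_size=2, pipeline_nodes=4), not values from this commit:

    inner_parallelism_size = 2
    sharding_group_size = 2
    pipeline_nodes = 4
    endpoints = ["127.0.0.1:{}".format(6000 + i) for i in range(16)]  # assumed endpoints

    for global_rank in range(len(endpoints)):
        # Sharding index within one pipeline stage (mirrors the else branch).
        sharding_rank = global_rank // inner_parallelism_size % sharding_group_size
        # Pipeline stage index: one stage spans sharding_group_size * inner_parallelism_size ranks.
        pp_rank = global_rank // (sharding_group_size * inner_parallelism_size)
        # Pipeline peers of this rank: the same position in every stage, one stage-width apart.
        offset = sharding_group_size * inner_parallelism_size
        idx_with_pp_0 = global_rank % offset
        pp_group = [endpoints[idx_with_pp_0 + i * offset] for i in range(pipeline_nodes)]
        print(global_rank, sharding_rank, pp_rank, pp_group)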
...