Commit 997651ab authored by sandyhouse

update, test=develop

Parent d3105dbf
@@ -99,8 +99,8 @@ void SectionWorker::TrainFiles() {
VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
gc.get());
DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
op.get(), unused_vars_, gc.get());
}
}
}
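The one-line change above points the garbage collector at the scope the update op actually ran in (the last microbatch scope) instead of the first one. A toy Python sketch of that invariant, with plain dicts standing in for Paddle scopes (not the real Scope/GC API):

num_microbatches = 4
microbatch_scopes = [{} for _ in range(num_microbatches)]

def run_update_op(scope):
    # The update-phase op writes its temporaries into the scope it runs in.
    scope["opt_tmp"] = "temporary produced by the update op"

def delete_unused_tensors(scope, unused_vars):
    for name in unused_vars:
        scope.pop(name, None)

update_scope = microbatch_scopes[num_microbatches - 1]  # same scope as op->Run above
run_update_op(update_scope)
delete_unused_tensors(update_scope, ["opt_tmp"])  # scope 0 never held this temporary
assert "opt_tmp" not in update_scope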
@@ -40,6 +40,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"LarsOptimizer",
"LambOptimizer",
"ModelParallelOptimizer",
"PipelineOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None
@@ -98,14 +99,14 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt)
main_program = loss.block.program
main_program._pipeline_opt = dict()
pp_rank = self.role_maker._worker_index(
) // self.user_defined_strategy.sharding_configs[
'sharding_group_size']
pp_rank = self.role_maker._worker_index() // (
self.user_defined_strategy.sharding_configs[
'sharding_group_size'] * self._inner_parallelism_size)
main_program._pipeline_opt['local_rank'] = pp_rank
main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
main_program._pipeline_opt['ring_id'] = 1
main_program._pipeline_opt['ring_id'] = 2
optimize_ops, params_grads, program_list = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pipeline_nodes = len(program_list)
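For reference, a small sketch of the rank arithmetic introduced above; the worker layout it assumes (inner model parallelism varying fastest, then sharding, then pipeline) is inferred from the index math in the patch rather than stated anywhere:

def pipeline_rank(worker_index, sharding_group_size, inner_parallelism_size):
    # Each pipeline stage owns sharding_group_size * inner_parallelism_size workers,
    # so integer division by that block size yields the stage (pp) rank.
    return worker_index // (sharding_group_size * inner_parallelism_size)

# Hypothetical example: 8 workers, sharding groups of 2, inner parallelism of 2
# -> 2 pipeline stages; workers 0-3 land on stage 0 and workers 4-7 on stage 1.
assert [pipeline_rank(i, 2, 2) for i in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]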
@@ -358,16 +359,19 @@ class ShardingOptimizer(MetaOptimizerBase):
# config sharding & dp groups
self._init_comm()
# sharding
print("sharding_group_endpoints:", self.sharding_group_endpoints)
print("sharding_rank:", self.sharding_rank)
print("sharding_ring_id:", self.sharding_ring_id)
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.sharding_group_endpoints, self.sharding_rank,
self.sharding_ring_id, True)
# inner & outer model parallelism
if self._as_outer_parallelism:
self._collective_helper._init_communicator(
self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
# if self._as_outer_parallelism:
# self._collective_helper._init_communicator(
# self._startup_program, self.current_endpoint,
# self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True)
# dp
if self.hybrid_dp:
@@ -757,7 +761,7 @@ class ShardingOptimizer(MetaOptimizerBase):
logging.info("Using Sharing&DP mode !")
else:
if self._as_outer_parallelism:
if self._as_outer_parallelism and not self.use_pipeline:
self.sharding_ring_id = 1
assert self.global_word_size > self._inner_parallelism_size, \
"global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size)
@@ -801,75 +805,113 @@ class ShardingOptimizer(MetaOptimizerBase):
# logging.info("megatron endpoints: {}".format(
# magetron_endpoints))
if self.use_pipeline:
self.sharding_ring_id = 0
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = self.global_rank % self.sharding_group_size
assert self.sharding_group_size * self.pipeline_nodes == self.role_maker._worker_num(
)
self.pp_ring_id = 1
self.pp_rank = self.global_rank // self.sharding_group_size
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self.sharding_group_size) == self.pp_rank
]
self.pp_group_size = self.pipeline_nodes
self.pp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size) == self.sharding_rank
]
if self._inner_parallelism_size == 1:
self.sharding_ring_id = 0
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = self.global_rank % self.sharding_group_size
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
)
self.pp_ring_id = 2
self.pp_rank = self.global_rank // (
self.sharding_group_size * self._inner_parallelism_size)
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self.sharding_group_size) == self.pp_rank
]
self.pp_group_size = self.pipeline_nodes
self.pp_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx % self.sharding_group_size
) == self.sharding_rank
]
else:
self.sharding_ring_id = 1
self.pp_ring_id = 2
# self.cards_per_node = 8
self.sharding_group_size = self.user_defined_strategy.sharding_configs[
'sharding_group_size']
self.sharding_rank = self.global_rank // self._inner_parallelism_size % self.sharding_group_size
# self.sharding_group_id = self.global_rank // (self._inner_parallelism_size % self.sharding_group_size)
self.sharding_group_endpoints = [
ep for idx, ep in enumerate(self.endpoints)
if (idx // self._inner_parallelism_size %
self.sharding_group_size) == self.sharding_rank
]
assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num(
)
self.pp_rank = self.global_rank // (
self.sharding_group_size * self._inner_parallelism_size)
offset = self.sharding_group_size * self._inner_parallelism_size
idx_with_pp_0 = self.global_rank % (
self.sharding_group_size * self._inner_parallelism_size)
self.pp_group_endpoints = []
for i in range(self.pipeline_nodes):
self.pp_group_endpoints.append(self.endpoints[
idx_with_pp_0])
idx_with_pp_0 += offset
#self.pp_group_endpoints = [
# ep for idx, ep in enumerate(self.endpoints)
# if (idx % self.sharding_group_size) == self.sharding_rank
#]
self.mp_group_id = 1
self.mp_rank = self.global_rank
self.mp_group_size = self.role_maker._worker_num()
self.mp_group_endpoints = self.endpoints[:]
logging.info("Using Sharing as Outer parallelism mode !")
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
logging.info("Using Sharing with pipeline !")
else:
self.sharding_ring_id = 0
self.sharding_rank = self.global_rank
self.sharding_group_size = self.role_maker._worker_num()
self.sharding_group_endpoints = self.endpoints
#else:
# self.sharding_ring_id = 0
# self.sharding_rank = self.global_rank
# self.sharding_group_size = self.role_maker._worker_num()
# self.sharding_group_endpoints = self.endpoints
# sharding parallelism is the only model parallelism in the current setting
self.mp_group_id = self.sharding_ring_id
self.mp_rank = self.sharding_rank
self.mp_group_size = self.sharding_group_size
self.mp_group_endpoints = self.sharding_group_endpoints[:]
# # sharding parallelism is the only model parallelism in the current setting
# self.mp_group_id = self.sharding_ring_id
# self.mp_rank = self.sharding_rank
# self.mp_group_size = self.sharding_group_size
# self.mp_group_endpoints = self.sharding_group_endpoints[:]
logging.info("Using Sharing alone mode !")
# logging.info("Using Sharing alone mode !")
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_size = None
self.pp_group_endpoints = None
self.dp_ring_id = -1
self.dp_rank = -1
self.dp_group_size = None
self.dp_group_endpoints = None
#self.pp_ring_id = -1
#self.pp_rank = -1
#self.pp_group_size = None
#self.pp_group_endpoints = None
#self.dp_ring_id = -1
#self.dp_rank = -1
#self.dp_group_size = None
#self.dp_group_endpoints = None
logging.info("Using Sharing alone mode !")
logging.info("global word size: {}".format(self.global_word_size))
logging.info("global rank: {}".format(self.global_rank))
logging.info("sharding group_size: {}".format(self.sharding_group_size))
logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("current model parallelism group_size: {}".format(
self.mp_group_size))
logging.info("current model parallelism rank: {}".format(self.mp_rank))
logging.info("dp group size: {}".format(self.dp_group_size))
logging.info("dp rank: {}".format(self.dp_rank))
logging.info("current endpoint: {}".format(self.current_endpoint))
logging.info("global word endpoints: {}".format(self.endpoints))
logging.info("sharding group endpoints: {}".format(
self.sharding_group_endpoints))
logging.info("current model parallelism group endpoints: {}".format(
self.mp_group_endpoints))
logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
#logging.info("global word size: {}".format(self.global_word_size))
#logging.info("global rank: {}".format(self.global_rank))
#logging.info("sharding group_size: {}".format(self.sharding_group_size))
#logging.info("sharding rank: {}".format(self.sharding_rank))
#logging.info("current model parallelism group_size: {}".format(
# self.mp_group_size))
#logging.info("current model parallelism rank: {}".format(self.mp_rank))
#logging.info("dp group size: {}".format(self.dp_group_size))
#logging.info("dp rank: {}".format(self.dp_rank))
#logging.info("current endpoint: {}".format(self.current_endpoint))
#logging.info("global word endpoints: {}".format(self.endpoints))
#logging.info("sharding group endpoints: {}".format(
# self.sharding_group_endpoints))
#logging.info("current model parallelism group endpoints: {}".format(
# self.mp_group_endpoints))
#logging.info("dp group endpoints: {}".format(self.dp_group_endpoints))
return
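To make the endpoint arithmetic in the use_pipeline branch with _inner_parallelism_size > 1 concrete, here is a hypothetical worked example (2-way inner parallelism, sharding groups of 2, 2 pipeline stages, so 8 workers); the endpoint names are made up for illustration:

inner_parallelism_size = 2
sharding_group_size = 2
pipeline_nodes = 2
endpoints = ["worker%d:6170" % i for i in range(8)]

global_rank = 5
offset = sharding_group_size * inner_parallelism_size  # 4 workers per pipeline stage
idx_with_pp_0 = global_rank % offset                   # this worker's counterpart on stage 0
pp_group_endpoints = []
for _ in range(pipeline_nodes):
    pp_group_endpoints.append(endpoints[idx_with_pp_0])
    idx_with_pp_0 += offset

# Rank 5 pairs with the worker at the same in-stage position on every stage:
assert pp_group_endpoints == ["worker1:6170", "worker5:6170"]

Each worker therefore joins a pipeline group whose members sit one sharding_group_size * inner_parallelism_size stride apart, which is exactly what the offset / idx_with_pp_0 loop in the hunk computes.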