未验证 提交 a1603797 编写于 作者: W WangXi 提交者: GitHub

[hybrid] refine sharding code (#34678)

上级 f30a5c42
...@@ -84,27 +84,23 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -84,27 +84,23 @@ class ShardingOptimizer(MetaOptimizerBase):
dist_strategy.sharding = True dist_strategy.sharding = True
dist_strategy.sharding_configs = {"segment_broadcast_MB": 32} dist_strategy.sharding_configs = {"segment_broadcast_MB": 32}
def minimize_impl(self, def _get_sharding_segment_strategy(self):
loss, """ get
startup_program=None, self._sharding_segment_strategy
parameter_list=None, 1. if by_size: self._broadcast_MB
no_grad_set=None): 2. if by_anchors: self._sharding_segment_anchors
# TODO: (JZ-LIANG) support multiple comm in future self._backward_remain_anchors
# self._nrings = self.user_defined_strategy.nccl_comm_num self._forward_remain_anchors
self._nrings_sharding = 1 """
self._nrings_dp = 1 strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
segment_strategy = str(sharding_configs["sharding_segment_strategy"])
# segment if segment_strategy == "segment_broadcast_MB":
self._sharding_segment_strategy = str( self._broadcast_MB = sharding_configs["segment_broadcast_MB"]
self.user_defined_strategy.sharding_configs[
"sharding_segment_strategy"])
if self._sharding_segment_strategy == "segment_broadcast_MB":
self._broadcast_MB = self.user_defined_strategy.sharding_configs[
"segment_broadcast_MB"]
assert self._broadcast_MB > 0, "segment size should larger than zero !" assert self._broadcast_MB > 0, "segment size should larger than zero !"
elif self._sharding_segment_strategy == "segment_anchors": elif segment_strategy == "segment_anchors":
self._sharding_segment_anchors = self.user_defined_strategy.sharding_configs[ self._sharding_segment_anchors = sharding_configs["segment_anchors"]
"segment_anchors"]
assert len(self._sharding_segment_anchors assert len(self._sharding_segment_anchors
) > 0, "you should set the sharding segment anchors !" ) > 0, "you should set the sharding segment anchors !"
self._backward_remain_anchors = self._sharding_segment_anchors[:] self._backward_remain_anchors = self._sharding_segment_anchors[:]
...@@ -112,82 +108,104 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -112,82 +108,104 @@ class ShardingOptimizer(MetaOptimizerBase):
else: else:
raise NotImplementedError( raise NotImplementedError(
"the sharding segment strategy [{}] is not implemented".format( "the sharding segment strategy [{}] is not implemented".format(
str(self._sharding_segment_strategy))) str(segment_strategy)))
self._sharding_segment_strategy = segment_strategy
def _get_hybrid_degree(self):
""" get
self.hybrid_dp
self.sharding_degree
self.mp_degree
self.pp_degree
self.dp_degree
"""
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
# parallelism # parallelism
self.sharding_degree = int(self.user_defined_strategy.sharding_configs[ sharding_degree = int(sharding_configs["sharding_degree"])
"sharding_degree"]) mp_degree = int(sharding_configs["mp_degree"])
assert self.sharding_degree > 0, "sharding degree must be larger than zero" pp_degree = int(sharding_configs["pp_degree"])
self.mp_degree = int(self.user_defined_strategy.sharding_configs[ dp_degree = int(sharding_configs['dp_degree'])
"mp_degree"]) global_world_size = self.role_maker._worker_num()
assert sharding_degree > 0, "sharding degree must be larger than zero"
# pipeline setting # pipeline setting
# TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline # TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
self.pp_degree = int(self.user_defined_strategy.sharding_configs[ if pp_degree > 1:
"pp_degree"]) assert strategy.pipeline is True
if self.pp_degree > 1:
assert self.user_defined_strategy.pipeline == True assert global_world_size == mp_degree * sharding_degree * pp_degree * dp_degree, \
"global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
self.dp_degree = int(self.user_defined_strategy.sharding_configs[ global_world_size, mp_degree, sharding_degree, pp_degree, dp_degree)
'dp_degree'])
assert self.role_maker._worker_num(
) == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
self.role_maker._worker_num(),
self.mp_degree,
self.sharding_degree,
self.pp_degree,
self.dp_degree, )
# FIXME (JZ-LIANG) deprecated hybrid_dp # FIXME (JZ-LIANG) deprecated hybrid_dp
if self.user_defined_strategy.sharding_configs["hybrid_dp"]: if sharding_configs["hybrid_dp"]:
logger.warning( logger.warning(
"[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically" "[hybrid_dp] API setting is deprecated. Now when "
) "dp_degree >= 2, its will be in hybrid dp mode automatically")
assert self.dp_degree >= 1 assert dp_degree >= 1
if self.dp_degree > 1:
self.hybrid_dp = True self.hybrid_dp = True if dp_degree > 1 else False
else: self.sharding_degree = sharding_degree
self.hybrid_dp = False self.mp_degree = mp_degree
self.pp_degree = pp_degree
self.dp_degree = dp_degree
def _get_hybrid_dp_mode(self):
""" get
self.hybrid_dp_mode
self.gradient_merge_mode
self._gradient_merge_acc_step
self.pp_allreduce_in_optimize
"""
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
# NOTE (JZ-LIANG) # NOTE (JZ-LIANG)
# there 2 kind of modes for gradient-merge and hybrid-dp in mixed parallism [sharding] and [pipeline]. # There 2 kind of modes for gradient-merge and hybrid-dp in mixed parallelism [sharding] and [pipeline].
# we distinguish this two modes since the gm/hybrid-dp related allreduce should be insert in different place according different mode to have best performance: # We distinguish this two modes since the gm/hybrid-dp related allreduce should be insert in different place
# sharding: communication within node, and therefore should insert within backward segment to overlap with bw calc, conduct every micro step # according different mode to have best performance:
# pipeline: communication accross nodes, and therefore should insert in update segemnt, conduct just once per global step # sharding: communication within node, and therefore should insert within backward segment
self.hybrid_dp_mode = None # to overlap with bw calc, conduct every micro step.
# pipeline: communication across nodes, and therefore should insert in update segment,
# conduct just once per global step.
dp_mode = None
# dp here is the pure dp as the outest parallelism # dp here is the pure dp as the outest parallelism
if self.hybrid_dp: if self.hybrid_dp:
assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
self.dp_degree)
if self.pp_degree > 1: if self.pp_degree > 1:
self.hybrid_dp_mode = "pp_hybrid_dp" dp_mode = "pp_hybrid_dp"
else: else:
assert self.sharding_degree > 1, "by now we only support five kind of hybrid dp: sharding_hybrid_dp, mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp." assert self.sharding_degree > 1, \
self.hybrid_dp_mode = "sharding_hybrid_dp" "by now we only support five kind of hybrid dp: sharding_hybrid_dp, " \
"mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
dp_mode = "sharding_hybrid_dp"
# gradient merge # gradient merge
self._gradient_merge_acc_step = int( gm_mode = None
self.user_defined_strategy.sharding_configs[ gm_acc_step = int(sharding_configs["gradient_merge_acc_step"])
"gradient_merge_acc_step"])
self.gradient_merge_mode = None
if self.pp_degree <= 1: if self.pp_degree <= 1:
self.gradient_merge_mode = "sharding_gm" gm_mode = "sharding_gm"
self._grad2merged_grad = dict() self._grad2merged_grad = dict()
else: else:
self.gradient_merge_mode = "pp_gm" gm_mode = "pp_gm"
self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[ gm_acc_step = strategy.pipeline_configs['accumulate_steps']
'accumulate_steps'] if gm_acc_step > 1:
if self._gradient_merge_acc_step > 1:
logger.info("Gradient merge in [{}], acc step = [{}]".format( logger.info("Gradient merge in [{}], acc step = [{}]".format(
self.gradient_merge_mode, self._gradient_merge_acc_step)) gm_mode, gm_acc_step))
# optimize offload self.hybrid_dp_mode = dp_mode
self.optimize_offload = self.user_defined_strategy.sharding_configs[ self.gradient_merge_mode = gm_mode
"optimize_offload"] self._gradient_merge_acc_step = gm_acc_step
# this feature is design for ascend, and should NOT be used in GPU training # this feature is design for ascend, and should NOT be used in GPU training
self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[ self.pp_allreduce_in_optimize = sharding_configs[
"pp_allreduce_in_optimize"] "pp_allreduce_in_optimize"]
def _inner_opt_minimize(self, loss, startup_program, parameter_list,
no_grad_set):
pipeline_configs = self.user_defined_strategy.pipeline_configs
if self.inner_opt is None: if self.inner_opt is None:
raise ValueError( raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.") "self.inner_opt of ShardingOptimizer should not be None.")
...@@ -195,32 +213,29 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -195,32 +213,29 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.pp_degree > 1: if self.pp_degree > 1:
pp_optimizer = fluid.optimizer.PipelineOptimizer( pp_optimizer = fluid.optimizer.PipelineOptimizer(
self.inner_opt, self._gradient_merge_acc_step) self.inner_opt, self._gradient_merge_acc_step)
self._pp_optimizer = pp_optimizer
strategy = self.user_defined_strategy global_rank = self.role_maker._worker_index()
self.schedule_mode = strategy.pipeline_configs['schedule_mode'] schedule_mode = pipeline_configs['schedule_mode']
self.pp_rank_ = self.role_maker._worker_index() // (
self.sharding_degree * self.mp_degree) % self.pp_degree
pipeline_opt = dict()
pipeline_opt['schedule_mode'] = self.schedule_mode
pipeline_opt['micro_batch_size'] = strategy.pipeline_configs[
'micro_batch_size']
pipeline_opt['local_rank'] = self.pp_rank_
pipeline_opt['global_rank'] = self.role_maker._worker_index()
pipeline_opt['use_sharding'] = True
# TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
pipeline_opt['ring_id'] = 20
pipeline_opt['global_ring_id'] = 3
pipeline_opt['mp_degree'] = self.mp_degree
pipeline_opt['mp_rank'] = self.role_maker._worker_index(
) % self.mp_degree
pipeline_opt = {
'schedule_mode': schedule_mode,
'micro_batch_size': pipeline_configs['micro_batch_size'],
'local_rank': self.pp_rank,
'global_rank': global_rank,
'use_sharding': True,
# TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
'ring_id': 20,
'global_ring_id': 3,
'mp_degree': self.mp_degree,
'mp_rank': global_rank % self.mp_degree,
}
main_program = loss.block.program main_program = loss.block.program
main_program._pipeline_opt = pipeline_opt main_program._pipeline_opt = pipeline_opt
optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize( optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set) loss, startup_program, parameter_list, no_grad_set)
self.pp_degree = len(program_list) assert self.pp_degree == len(program_list)
else: else:
optimize_ops, params_grads = self.inner_opt.minimize( optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set) loss, startup_program, parameter_list, no_grad_set)
...@@ -230,9 +245,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -230,9 +245,8 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.pp_degree > 1: if self.pp_degree > 1:
startup_program = startup_program._pipeline_opt['startup_program'] startup_program = startup_program._pipeline_opt['startup_program']
#main_program = main_program._pipeline_opt['section_program']['program'] print("pp_rank:", self.pp_rank)
print("pp_rank:", self.pp_rank_) main_program = program_list[self.pp_rank]
main_program = program_list[self.pp_rank_]
with open("main_%d" % self.role_maker._worker_index(), 'w') as f: with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program)) f.writelines(str(main_program))
main_block = main_program.global_block() main_block = main_program.global_block()
...@@ -241,7 +255,6 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -241,7 +255,6 @@ class ShardingOptimizer(MetaOptimizerBase):
if main_block.has_var(param.name): if main_block.has_var(param.name):
new_params_grads.append((param, grad)) new_params_grads.append((param, grad))
params_grads = new_params_grads params_grads = new_params_grads
else: else:
main_block = loss.block main_block = loss.block
...@@ -254,10 +267,13 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -254,10 +267,13 @@ class ShardingOptimizer(MetaOptimizerBase):
with open("main_%d" % self.role_maker._worker_index(), 'w') as f: with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program)) f.writelines(str(main_program))
# step0: _init_comm return optimize_ops, params_grads
self._init_comm()
if self.sharding_degree > 1: def _apply_sharding_pass(self, params_grads):
if self.sharding_degree == 1: return
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# step1: build shard # step1: build shard
self._build_shard(params_grads) self._build_shard(params_grads)
...@@ -270,13 +286,17 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -270,13 +286,17 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._sync_with_cpp() main_block._sync_with_cpp()
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
main_block._sync_with_cpp()
# step4: remove unneeded ops and vars from block # step4: remove unneeded ops and vars from block
self._prune_main_program(main_block) self._prune_main_program(main_block)
self._prune_startup_program(startup_block) self._prune_startup_program(startup_block)
if self.pp_degree > 1: def _insert_allreduce_for_pp(self):
if self.pp_degree == 1: return
strategy = self.user_defined_strategy
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# sharding-pp related logic # sharding-pp related logic
# pp_optimizer._rename_gradient_var_name(main_block) # pp_optimizer._rename_gradient_var_name(main_block)
# crop ops # crop ops
...@@ -296,14 +316,14 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -296,14 +316,14 @@ class ShardingOptimizer(MetaOptimizerBase):
if in_name not in main_block.vars: if in_name not in main_block.vars:
main_block._remove_op(idx) main_block._remove_op(idx)
accumulated_grad_names = pp_optimizer._accumulate_gradients( accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
main_block) main_block)
# accumulated_grad_names = sorted(accumulated_grad_names) # accumulated_grad_names = sorted(accumulated_grad_names)
if self.pp_allreduce_in_optimize: if self.pp_allreduce_in_optimize:
print("persistable FP32 grad: ") print("persistable FP32 grad: ")
print(accumulated_grad_names) print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block, raise_error=self.user_defined_strategy.amp) main_block, raise_error=strategy.amp)
insert_reduce_ops( insert_reduce_ops(
main_block, main_block,
first_optimize_op_index, first_optimize_op_index,
...@@ -314,7 +334,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -314,7 +334,7 @@ class ShardingOptimizer(MetaOptimizerBase):
use_calc_stream=True) use_calc_stream=True)
if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp": if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block, raise_error=self.user_defined_strategy.amp) main_block, raise_error=strategy.amp)
if first_optimize_op_index >= 0: if first_optimize_op_index >= 0:
insert_allreduce_ops( insert_allreduce_ops(
main_block, main_block,
...@@ -323,23 +343,29 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -323,23 +343,29 @@ class ShardingOptimizer(MetaOptimizerBase):
accumulated_grad_names, accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize, core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True, use_calc_stream=True,
user_defined_strategy=self.user_defined_strategy) user_defined_strategy=strategy)
def _adapt_amp_clip_without_sharding(self):
if self.sharding_degree > 1: return
# if not use sharding, adapt amp/clip, for remain parallelism. # if not use sharding, adapt amp/clip, for remain parallelism.
# cast --> amp --> clip --> opt # cast --> amp --> clip --> opt
if self.sharding_degree <= 1:
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# FIXME(wangxi): mp should prune duplicated param_grads when calc # FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var # amp inf_var & clip global_norm_var
# amp FP16Utils.sync_amp_check_nan_inf(main_block,
FP16Utils.sync_amp_check_nan_inf( [self.mp_ring_id, self.pp_ring_id])
main_block, [self.mp_ring_id, self.pp_ring_id])
# clip
gradientclip_helper = GradientClipHelper(None) gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.sync_global_norm( gradientclip_helper.sync_global_norm(
main_block, [self.mp_ring_id, self.pp_ring_id]) main_block, [self.mp_ring_id, self.pp_ring_id])
def _insert_loss_grad_scale_op(self):
main_block = self._main_program.global_block()
# step6: loss div dp_degree # step6: loss div dp_degree
global_dp_degree = self.sharding_degree * self.dp_degree global_dp_degree = self.sharding_degree * self.dp_degree
assert int(global_dp_degree) == global_dp_degree assert int(global_dp_degree) == global_dp_degree
...@@ -348,18 +374,67 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -348,18 +374,67 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._sync_with_cpp() main_block._sync_with_cpp()
# TODO(wangxi): add optimize offload def _apply_optimize_offload_pass(self):
# opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. >> 100) strategy = self.user_defined_strategy
# sync its memcpy could not be overlap with calc, otherwise it will slower down training severely. sharding_configs = strategy.sharding_configs
if self.optimize_offload: main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# optimize offload should be enable while gradient merge is enable and
# acc_step is quite large (e.g. >> 100). Since its memcpy could not be
# overlap with calc, otherwise it will slower down training severely.
if sharding_configs["optimize_offload"]:
logger.info("Sharding with optimize offload !") logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper() offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block) offload_helper.offload(main_block, startup_block)
offload_helper.offload_fp32param(main_block, startup_block) offload_helper.offload_fp32param(main_block, startup_block)
def _dump_program_for_debug(self):
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
self._nrings_dp = 1
self._get_sharding_segment_strategy()
self._get_hybrid_degree()
self._get_hybrid_dp_mode()
# config sharding & dp groups
self._build_groups()
# inner optimize minimize
optimize_ops, params_grads = self._inner_opt_minimize(
loss, startup_program, parameter_list, no_grad_set)
self._init_comm()
self._apply_sharding_pass(params_grads)
self._insert_allreduce_for_pp()
self._adapt_amp_clip_without_sharding()
# loss div dp_degree
self._insert_loss_grad_scale_op()
self._apply_optimize_offload_pass()
# step6: (optional) sharding gradient merge # step6: (optional) sharding gradient merge
if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: self._sharding_gradient_merge()
self._sharding_gradient_merge(main_block)
# # check op dependecy # # check op dependecy
# FIXME (JZ-LIANG) enable checking in future. # FIXME (JZ-LIANG) enable checking in future.
...@@ -367,17 +442,11 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -367,17 +442,11 @@ class ShardingOptimizer(MetaOptimizerBase):
# check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, # check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
# self.dp_ring_id) # self.dp_ring_id)
if self.hybrid_dp:
# NOTE(JZ-LIANG) ensure in both sharding_hybrid_dp & pp_hybrid_dp # NOTE(JZ-LIANG) ensure in both sharding_hybrid_dp & pp_hybrid_dp
# init param broadcast should be called after startup pruning # init param broadcast should be called after startup pruning
self._initialization_broadcast(startup_block) self._initialization_broadcast()
with open("start_sharding_%d" % self.role_maker._worker_index(), self._dump_program_for_debug()
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
# GPU need to wait server ready, GPU and NPU is Layered connection # GPU need to wait server ready, GPU and NPU is Layered connection
if not core.is_compiled_with_npu(): if not core.is_compiled_with_npu():
...@@ -471,9 +540,6 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -471,9 +540,6 @@ class ShardingOptimizer(MetaOptimizerBase):
def _init_pipeline_comm(self, startup_block): def _init_pipeline_comm(self, startup_block):
# TODO (JZ-LIANG) to unify pp_rank_ and pp_rank # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
self._collective_helper._init_communicator( self._collective_helper._init_communicator(
self._startup_program, self._startup_program,
self.current_endpoint, self.current_endpoint,
...@@ -496,17 +562,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -496,17 +562,8 @@ class ShardingOptimizer(MetaOptimizerBase):
self._init_pair_comm(pair, ring_id) self._init_pair_comm(pair, ring_id)
def _init_comm(self): def _init_comm(self):
# config sharding & dp groups
self._build_groups()
# sync var # sync var
startup_block = self._startup_program.global_block() startup_block = self._startup_program.global_block()
self.startup_prog_sync_var = startup_block.create_var(
name="startup_prog_sync_var",
shape=[1],
dtype=core.VarDesc.VarType.INT32,
persistable=False)
# mp ring # mp ring
if self.mp_degree > 1: if self.mp_degree > 1:
...@@ -1051,7 +1108,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1051,7 +1108,8 @@ class ShardingOptimizer(MetaOptimizerBase):
sharding: 1 sharding: 1
pure-dp: 2 pure-dp: 2
global: 3 global: 3
pp: >= 20 pp: 4
pp-pair: >= 20
if one parallelism is not enable: -1 if one parallelism is not enable: -1
and only support parallelism hierarchy: mp --> sharding --> pp --> dp and only support parallelism hierarchy: mp --> sharding --> pp --> dp
""" """
...@@ -1216,11 +1274,16 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1216,11 +1274,16 @@ class ShardingOptimizer(MetaOptimizerBase):
return return
def _initialization_broadcast(self, startup_block): def _initialization_broadcast(self):
""" """
this funtion is to ensure the initialization between dp group to be this funtion is to ensure the initialization between dp group to be
identical when hybrid-dp is used. identical when hybrid-dp is used.
""" """
if not self.hybrid_dp:
return
startup_block = self._startup_program.global_block()
params = [] params = []
for param in startup_block.iter_parameters(): for param in startup_block.iter_parameters():
params.append(param) params.append(param)
...@@ -1461,13 +1524,17 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1461,13 +1524,17 @@ class ShardingOptimizer(MetaOptimizerBase):
# lr_var = main_block.var("gradient_merge_current_step") # lr_var = main_block.var("gradient_merge_current_step")
# paddle.static.Print(lr_var, message="in OPTIMIZE last conditional") # paddle.static.Print(lr_var, message="in OPTIMIZE last conditional")
def _sharding_gradient_merge(self, main_block): def _sharding_gradient_merge(self):
""" """
copy all optimize ops in origin main block copy all optimize ops in origin main block
remove all optimize ops in origin main block remove all optimize ops in origin main block
create cond block create cond block
""" """
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
return
main_block = self._main_program.global_block()
# copy original optimize ops to temp ops desc list # copy original optimize ops to temp ops desc list
# remove them from block 0 # remove them from block 0
tmp_copy_block = self._main_program._create_block() tmp_copy_block = self._main_program._create_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册