Unverified commit a1603797, authored by: W WangXi, committed by: GitHub

[hybrid] refine sharding code (#34678)

Parent f30a5c42
......@@ -84,27 +84,23 @@ class ShardingOptimizer(MetaOptimizerBase):
dist_strategy.sharding = True
dist_strategy.sharding_configs = {"segment_broadcast_MB": 32}
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
self._nrings_dp = 1
def _get_sharding_segment_strategy(self):
""" get
self._sharding_segment_strategy
1. if by_size: self._broadcast_MB
2. if by_anchors: self._sharding_segment_anchors
self._backward_remain_anchors
self._forward_remain_anchors
"""
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
segment_strategy = str(sharding_configs["sharding_segment_strategy"])
# segment
self._sharding_segment_strategy = str(
self.user_defined_strategy.sharding_configs[
"sharding_segment_strategy"])
if self._sharding_segment_strategy == "segment_broadcast_MB":
self._broadcast_MB = self.user_defined_strategy.sharding_configs[
"segment_broadcast_MB"]
if segment_strategy == "segment_broadcast_MB":
self._broadcast_MB = sharding_configs["segment_broadcast_MB"]
assert self._broadcast_MB > 0, "segment size should be larger than zero!"
elif self._sharding_segment_strategy == "segment_anchors":
self._sharding_segment_anchors = self.user_defined_strategy.sharding_configs[
"segment_anchors"]
elif segment_strategy == "segment_anchors":
self._sharding_segment_anchors = sharding_configs["segment_anchors"]
assert len(self._sharding_segment_anchors
) > 0, "you should set the sharding segment anchors !"
self._backward_remain_anchors = self._sharding_segment_anchors[:]
......@@ -112,82 +108,104 @@ class ShardingOptimizer(MetaOptimizerBase):
else:
raise NotImplementedError(
"the sharding segment strategy [{}] is not implemented".format(
str(self._sharding_segment_strategy)))
str(segment_strategy)))
self._sharding_segment_strategy = segment_strategy
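# Illustrative configuration (values assumed for illustration, keys as used above):
#   strategy.sharding_configs = {"sharding_segment_strategy": "segment_broadcast_MB",
#                                "segment_broadcast_MB": 32}
# the "segment_anchors" strategy would instead require a non-empty "segment_anchors" list.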
def _get_hybrid_degree(self):
""" get
self.hybrid_dp
self.sharding_degree
self.mp_degree
self.pp_degree
self.dp_degree
"""
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
# parallelism
self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
"sharding_degree"])
assert self.sharding_degree > 0, "sharding degree must be larger than zero"
self.mp_degree = int(self.user_defined_strategy.sharding_configs[
"mp_degree"])
sharding_degree = int(sharding_configs["sharding_degree"])
mp_degree = int(sharding_configs["mp_degree"])
pp_degree = int(sharding_configs["pp_degree"])
dp_degree = int(sharding_configs['dp_degree'])
global_world_size = self.role_maker._worker_num()
assert sharding_degree > 0, "sharding degree must be larger than zero"
# pipeline setting
# TODO (JZ-LIANG) should revise here to support mixed parallelism with pipeline
self.pp_degree = int(self.user_defined_strategy.sharding_configs[
"pp_degree"])
if self.pp_degree > 1:
assert self.user_defined_strategy.pipeline == True
self.dp_degree = int(self.user_defined_strategy.sharding_configs[
'dp_degree'])
assert self.role_maker._worker_num(
) == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
self.role_maker._worker_num(),
self.mp_degree,
self.sharding_degree,
self.pp_degree,
self.dp_degree, )
if pp_degree > 1:
assert strategy.pipeline is True
assert global_world_size == mp_degree * sharding_degree * pp_degree * dp_degree, \
"global world size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
global_world_size, mp_degree, sharding_degree, pp_degree, dp_degree)
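# Illustrative arithmetic (degrees assumed for illustration): with 32 workers,
# mp_degree=2, sharding_degree=4, pp_degree=2, dp_degree=2 satisfies this check,
# since 2 * 4 * 2 * 2 == 32; any other product would trip the assert above.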
# FIXME (JZ-LIANG) deprecated hybrid_dp
if self.user_defined_strategy.sharding_configs["hybrid_dp"]:
if sharding_configs["hybrid_dp"]:
logger.warning(
"[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically"
)
assert self.dp_degree >= 1
if self.dp_degree > 1:
self.hybrid_dp = True
else:
self.hybrid_dp = False
# NOTE (JZ-LIANG)
# there 2 kind of modes for gradient-merge and hybrid-dp in mixed parallism [sharding] and [pipeline].
# we distinguish this two modes since the gm/hybrid-dp related allreduce should be insert in different place according different mode to have best performance:
# sharding: communication within node, and therefore should insert within backward segment to overlap with bw calc, conduct every micro step
# pipeline: communication accross nodes, and therefore should insert in update segemnt, conduct just once per global step
self.hybrid_dp_mode = None
"[hybrid_dp] API setting is deprecated. Now when "
"dp_degree >= 2, its will be in hybrid dp mode automatically")
assert dp_degree >= 1
self.hybrid_dp = dp_degree > 1
self.sharding_degree = sharding_degree
self.mp_degree = mp_degree
self.pp_degree = pp_degree
self.dp_degree = dp_degree
def _get_hybrid_dp_mode(self):
""" get
self.hybrid_dp_mode
self.gradient_merge_mode
self._gradient_merge_acc_step
self.pp_allreduce_in_optimize
"""
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
# NOTE (JZ-LIANG)
# There are 2 kinds of modes for gradient-merge and hybrid-dp in mixed parallelism [sharding] and [pipeline].
# We distinguish these two modes since the gm/hybrid-dp related allreduce should be inserted in different places
# according to the mode to achieve the best performance:
# sharding: communication is within a node, and therefore should be inserted within the backward segment
# to overlap with the backward calc; conducted every micro step.
# pipeline: communication is across nodes, and therefore should be inserted in the update segment;
# conducted just once per global step.
dp_mode = None
# dp here is the pure dp as the outermost parallelism
if self.hybrid_dp:
assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
self.dp_degree)
if self.pp_degree > 1:
self.hybrid_dp_mode = "pp_hybrid_dp"
dp_mode = "pp_hybrid_dp"
else:
assert self.sharding_degree > 1, "by now we only support five kind of hybrid dp: sharding_hybrid_dp, mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
self.hybrid_dp_mode = "sharding_hybrid_dp"
assert self.sharding_degree > 1, \
"for now we only support five kinds of hybrid dp: sharding_hybrid_dp, " \
"mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
dp_mode = "sharding_hybrid_dp"
# gradient merge
self._gradient_merge_acc_step = int(
self.user_defined_strategy.sharding_configs[
"gradient_merge_acc_step"])
self.gradient_merge_mode = None
gm_mode = None
gm_acc_step = int(sharding_configs["gradient_merge_acc_step"])
if self.pp_degree <= 1:
self.gradient_merge_mode = "sharding_gm"
gm_mode = "sharding_gm"
self._grad2merged_grad = dict()
else:
self.gradient_merge_mode = "pp_gm"
self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[
'accumulate_steps']
if self._gradient_merge_acc_step > 1:
gm_mode = "pp_gm"
gm_acc_step = strategy.pipeline_configs['accumulate_steps']
if gm_acc_step > 1:
logger.info("Gradient merge in [{}], acc step = [{}]".format(
self.gradient_merge_mode, self._gradient_merge_acc_step))
gm_mode, gm_acc_step))
# optimize offload
self.optimize_offload = self.user_defined_strategy.sharding_configs[
"optimize_offload"]
self.hybrid_dp_mode = dp_mode
self.gradient_merge_mode = gm_mode
self._gradient_merge_acc_step = gm_acc_step
# this feature is designed for Ascend, and should NOT be used in GPU training
self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
self.pp_allreduce_in_optimize = sharding_configs[
"pp_allreduce_in_optimize"]
def _inner_opt_minimize(self, loss, startup_program, parameter_list,
no_grad_set):
pipeline_configs = self.user_defined_strategy.pipeline_configs
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
......@@ -195,32 +213,29 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.pp_degree > 1:
pp_optimizer = fluid.optimizer.PipelineOptimizer(
self.inner_opt, self._gradient_merge_acc_step)
strategy = self.user_defined_strategy
self.schedule_mode = strategy.pipeline_configs['schedule_mode']
self.pp_rank_ = self.role_maker._worker_index() // (
self.sharding_degree * self.mp_degree) % self.pp_degree
pipeline_opt = dict()
pipeline_opt['schedule_mode'] = self.schedule_mode
pipeline_opt['micro_batch_size'] = strategy.pipeline_configs[
'micro_batch_size']
pipeline_opt['local_rank'] = self.pp_rank_
pipeline_opt['global_rank'] = self.role_maker._worker_index()
pipeline_opt['use_sharding'] = True
# TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
pipeline_opt['ring_id'] = 20
pipeline_opt['global_ring_id'] = 3
pipeline_opt['mp_degree'] = self.mp_degree
pipeline_opt['mp_rank'] = self.role_maker._worker_index(
) % self.mp_degree
self._pp_optimizer = pp_optimizer
global_rank = self.role_maker._worker_index()
schedule_mode = pipeline_configs['schedule_mode']
pipeline_opt = {
'schedule_mode': schedule_mode,
'micro_batch_size': pipeline_configs['micro_batch_size'],
'local_rank': self.pp_rank,
'global_rank': global_rank,
'use_sharding': True,
# TODO (JZ-LIANG) should revise here to support mixed parallelism with pipeline
'ring_id': 20,
'global_ring_id': 3,
'mp_degree': self.mp_degree,
'mp_rank': global_rank % self.mp_degree,
}
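# NOTE: the hard-coded 'ring_id': 20 and 'global_ring_id': 3 match the ring-id
# layout documented later in this file (global: 3, pp-pair: >= 20).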
main_program = loss.block.program
main_program._pipeline_opt = pipeline_opt
optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pp_degree = len(program_list)
assert self.pp_degree == len(program_list)
else:
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
......@@ -230,9 +245,8 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.pp_degree > 1:
startup_program = startup_program._pipeline_opt['startup_program']
#main_program = main_program._pipeline_opt['section_program']['program']
print("pp_rank:", self.pp_rank_)
main_program = program_list[self.pp_rank_]
print("pp_rank:", self.pp_rank)
main_program = program_list[self.pp_rank]
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
main_block = main_program.global_block()
......@@ -241,7 +255,6 @@ class ShardingOptimizer(MetaOptimizerBase):
if main_block.has_var(param.name):
new_params_grads.append((param, grad))
params_grads = new_params_grads
else:
main_block = loss.block
......@@ -254,93 +267,106 @@ class ShardingOptimizer(MetaOptimizerBase):
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
# step0: _init_comm
self._init_comm()
return optimize_ops, params_grads
if self.sharding_degree > 1:
def _apply_sharding_pass(self, params_grads):
if self.sharding_degree == 1: return
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# step1: build shard
self._build_shard(params_grads)
# step1: build shard
self._build_shard(params_grads)
# step2: split_program
self._split_program(main_block)
# step2: split_program
self._split_program(main_block)
# step3: add broadcast and reduce ops
self._add_broadcast_allreduce(main_block)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
# step3: add broadcast and reduce ops
self._add_broadcast_allreduce(main_block)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
main_block._sync_with_cpp()
# step4: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
# step4: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
def _insert_allreduce_for_pp(self):
if self.pp_degree == 1: return
if self.pp_degree > 1:
# sharding-pp related logic
# pp_optimizer._rename_gradient_var_name(main_block)
# crop ops
if self.sharding_degree > 1:
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_update_op(op):
op_role_var = op.attr('op_role_var')
param_name = op_role_var[0]
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != 'cast': continue
in_name = op.input_arg_names[0]
if in_name not in self._params: continue
#if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
strategy = self.user_defined_strategy
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# sharding-pp related logic
# pp_optimizer._rename_gradient_var_name(main_block)
# crop ops
if self.sharding_degree > 1:
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_update_op(op):
op_role_var = op.attr('op_role_var')
param_name = op_role_var[0]
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
accumulated_grad_names = pp_optimizer._accumulate_gradients(
main_block)
# accumulated_grad_names = sorted(accumulated_grad_names)
if self.pp_allreduce_in_optimize:
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block, raise_error=self.user_defined_strategy.amp)
insert_reduce_ops(
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != 'cast': continue
in_name = op.input_arg_names[0]
if in_name not in self._params: continue
#if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
main_block._remove_op(idx)
accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
main_block)
# accumulated_grad_names = sorted(accumulated_grad_names)
if self.pp_allreduce_in_optimize:
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block, raise_error=strategy.amp)
insert_reduce_ops(
main_block,
first_optimize_op_index,
self.sharding_ring_id,
accumulated_grad_names,
self._shard,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block, raise_error=strategy.amp)
if first_optimize_op_index >= 0:
insert_allreduce_ops(
main_block,
first_optimize_op_index,
self.sharding_ring_id,
self.dp_ring_id,
accumulated_grad_names,
self._shard,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block, raise_error=self.user_defined_strategy.amp)
if first_optimize_op_index >= 0:
insert_allreduce_ops(
main_block,
first_optimize_op_index,
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True,
user_defined_strategy=self.user_defined_strategy)
use_calc_stream=True,
user_defined_strategy=strategy)
def _adapt_amp_clip_without_sharding(self):
if self.sharding_degree > 1: return
# if sharding is not used, adapt amp/clip for the remaining parallelism.
# cast --> amp --> clip --> opt
if self.sharding_degree <= 1:
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var
# amp
FP16Utils.sync_amp_check_nan_inf(
main_block, [self.mp_ring_id, self.pp_ring_id])
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var
# clip
gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.sync_global_norm(
main_block, [self.mp_ring_id, self.pp_ring_id])
FP16Utils.sync_amp_check_nan_inf(main_block,
[self.mp_ring_id, self.pp_ring_id])
# step6: loss div dp_degree
gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.sync_global_norm(
main_block, [self.mp_ring_id, self.pp_ring_id])
def _insert_loss_grad_scale_op(self):
main_block = self._main_program.global_block()
# step6: loss div dp_degree
global_dp_degree = self.sharding_degree * self.dp_degree
assert int(global_dp_degree) == global_dp_degree
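# Illustrative example (degrees assumed for illustration): sharding_degree=4 and
# dp_degree=2 give global_dp_degree=8, so the loss gradient is scaled by 1/8 on each
# worker, presumably so that gradient accumulation across the 8 replicas yields an
# average rather than a sum.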
if global_dp_degree > 1:
......@@ -348,18 +374,67 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._sync_with_cpp()
# TODO(wangxi): add optimize offload
# opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. >> 100)
# sync its memcpy could not be overlap with calc, otherwise it will slower down training severely.
if self.optimize_offload:
def _apply_optimize_offload_pass(self):
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
# optimize offload should be enabled only when gradient merge is enabled and
# acc_step is quite large (e.g. >> 100), since its memcpy cannot be
# overlapped with calc and would otherwise slow down training severely.
if sharding_configs["optimize_offload"]:
logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block)
offload_helper.offload_fp32param(main_block, startup_block)
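# Illustrative configuration for this pass (values assumed for illustration):
#   strategy.sharding_configs["optimize_offload"] = True
#   strategy.sharding_configs["gradient_merge_acc_step"] = 128  # >> 100, per the note above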
def _dump_program_for_debug(self):
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
# TODO: (JZ-LIANG) support multiple comm in future
# self._nrings = self.user_defined_strategy.nccl_comm_num
self._nrings_sharding = 1
self._nrings_dp = 1
self._get_sharding_segment_strategy()
self._get_hybrid_degree()
self._get_hybrid_dp_mode()
# config sharding & dp groups
self._build_groups()
# inner optimize minimize
optimize_ops, params_grads = self._inner_opt_minimize(
loss, startup_program, parameter_list, no_grad_set)
self._init_comm()
self._apply_sharding_pass(params_grads)
self._insert_allreduce_for_pp()
self._adapt_amp_clip_without_sharding()
# loss div dp_degree
self._insert_loss_grad_scale_op()
self._apply_optimize_offload_pass()
# step6: (optional) sharding gradient merge
if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self._sharding_gradient_merge(main_block)
self._sharding_gradient_merge()
# # check op dependency
# FIXME (JZ-LIANG) enable checking in future.
......@@ -367,17 +442,11 @@ class ShardingOptimizer(MetaOptimizerBase):
# check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
# self.dp_ring_id)
if self.hybrid_dp:
# NOTE(JZ-LIANG) ensure in both sharding_hybrid_dp & pp_hybrid_dp
# init param broadcast should be called after startup pruning
self._initialization_broadcast(startup_block)
# NOTE(JZ-LIANG) for both sharding_hybrid_dp & pp_hybrid_dp, the
# init param broadcast should be called after startup pruning
self._initialization_broadcast()
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
self._dump_program_for_debug()
# GPU needs to wait for the server to be ready; GPU and NPU have a layered connection
if not core.is_compiled_with_npu():
......@@ -471,9 +540,6 @@ class ShardingOptimizer(MetaOptimizerBase):
def _init_pipeline_comm(self, startup_block):
# TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
......@@ -496,17 +562,8 @@ class ShardingOptimizer(MetaOptimizerBase):
self._init_pair_comm(pair, ring_id)
def _init_comm(self):
# config sharding & dp groups
self._build_groups()
# sync var
startup_block = self._startup_program.global_block()
self.startup_prog_sync_var = startup_block.create_var(
name="startup_prog_sync_var",
shape=[1],
dtype=core.VarDesc.VarType.INT32,
persistable=False)
# mp ring
if self.mp_degree > 1:
......@@ -1051,7 +1108,8 @@ class ShardingOptimizer(MetaOptimizerBase):
sharding: 1
pure-dp: 2
global: 3
pp: >= 20
pp: 4
pp-pair: >= 20
if one parallelism is not enabled: -1
and only supports the parallelism hierarchy: mp --> sharding --> pp --> dp
"""
......@@ -1216,11 +1274,16 @@ class ShardingOptimizer(MetaOptimizerBase):
return
def _initialization_broadcast(self, startup_block):
def _initialization_broadcast(self):
"""
this function ensures that the initialization across the dp group is
identical when hybrid-dp is used.
"""
if not self.hybrid_dp:
return
startup_block = self._startup_program.global_block()
params = []
for param in startup_block.iter_parameters():
params.append(param)
......@@ -1461,13 +1524,17 @@ class ShardingOptimizer(MetaOptimizerBase):
# lr_var = main_block.var("gradient_merge_current_step")
# paddle.static.Print(lr_var, message="in OPTIMIZE last conditional")
def _sharding_gradient_merge(self, main_block):
def _sharding_gradient_merge(self):
"""
copy all optimize ops in the original main block
remove all optimize ops in the original main block
create cond block
"""
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
return
main_block = self._main_program.global_block()
# copy original optimize ops to temp ops desc list
# remove them from block 0
tmp_copy_block = self._main_program._create_block()
......