From a16037972010a4b6c983cc1e181a70b1e07cbc20 Mon Sep 17 00:00:00 2001
From: WangXi
Date: Tue, 10 Aug 2021 14:41:01 +0800
Subject: [PATCH] [hybrid] refine sharding code (#34678)

---
 .../meta_optimizers/sharding_optimizer.py     | 461 ++++++++++--------
 1 file changed, 264 insertions(+), 197 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 1f1960b1700..a5df9486da4 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -84,27 +84,23 @@ class ShardingOptimizer(MetaOptimizerBase):
         dist_strategy.sharding = True
         dist_strategy.sharding_configs = {"segment_broadcast_MB": 32}
 
-    def minimize_impl(self,
-                      loss,
-                      startup_program=None,
-                      parameter_list=None,
-                      no_grad_set=None):
-        # TODO: (JZ-LIANG) support multiple comm in future
-        # self._nrings = self.user_defined_strategy.nccl_comm_num
-        self._nrings_sharding = 1
-        self._nrings_dp = 1
+    def _get_sharding_segment_strategy(self):
+        """ get
+        self._sharding_segment_strategy
+        1. if by_size: self._broadcast_MB
+        2. if by_anchors: self._sharding_segment_anchors
+                          self._backward_remain_anchors
+                          self._forward_remain_anchors
+        """
+        strategy = self.user_defined_strategy
+        sharding_configs = strategy.sharding_configs
+        segment_strategy = str(sharding_configs["sharding_segment_strategy"])
 
-        # segment
-        self._sharding_segment_strategy = str(
-            self.user_defined_strategy.sharding_configs[
-                "sharding_segment_strategy"])
-        if self._sharding_segment_strategy == "segment_broadcast_MB":
-            self._broadcast_MB = self.user_defined_strategy.sharding_configs[
-                "segment_broadcast_MB"]
+        if segment_strategy == "segment_broadcast_MB":
+            self._broadcast_MB = sharding_configs["segment_broadcast_MB"]
             assert self._broadcast_MB > 0, "segment size should larger than zero !"
-        elif self._sharding_segment_strategy == "segment_anchors":
-            self._sharding_segment_anchors = self.user_defined_strategy.sharding_configs[
-                "segment_anchors"]
+        elif segment_strategy == "segment_anchors":
+            self._sharding_segment_anchors = sharding_configs["segment_anchors"]
             assert len(self._sharding_segment_anchors
                        ) > 0, "you should set the sharding segment anchors !"
             self._backward_remain_anchors = self._sharding_segment_anchors[:]
@@ -112,82 +108,104 @@ class ShardingOptimizer(MetaOptimizerBase):
         else:
             raise NotImplementedError(
                 "the sharding segment strategy [{}] is not implemented".format(
-                    str(self._sharding_segment_strategy)))
+                    str(segment_strategy)))
+        self._sharding_segment_strategy = segment_strategy
+
+    def _get_hybrid_degree(self):
+        """ get
+        self.hybrid_dp
+        self.sharding_degree
+        self.mp_degree
+        self.pp_degree
+        self.dp_degree
+        """
+        strategy = self.user_defined_strategy
+        sharding_configs = strategy.sharding_configs
 
         # parallelism
-        self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
-            "sharding_degree"])
-        assert self.sharding_degree > 0, "sharding degree must be larger than zero"
-        self.mp_degree = int(self.user_defined_strategy.sharding_configs[
-            "mp_degree"])
+        sharding_degree = int(sharding_configs["sharding_degree"])
+        mp_degree = int(sharding_configs["mp_degree"])
+        pp_degree = int(sharding_configs["pp_degree"])
+        dp_degree = int(sharding_configs['dp_degree'])
+        global_world_size = self.role_maker._worker_num()
+
+        assert sharding_degree > 0, "sharding degree must be larger than zero"
 
         # pipeline setting
         # TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
-        self.pp_degree = int(self.user_defined_strategy.sharding_configs[
-            "pp_degree"])
-        if self.pp_degree > 1:
-            assert self.user_defined_strategy.pipeline == True
-
-        self.dp_degree = int(self.user_defined_strategy.sharding_configs[
-            'dp_degree'])
-        assert self.role_maker._worker_num(
-        ) == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
-            self.role_maker._worker_num(),
-            self.mp_degree,
-            self.sharding_degree,
-            self.pp_degree,
-            self.dp_degree, )
+        if pp_degree > 1:
+            assert strategy.pipeline is True
+
+        assert global_world_size == mp_degree * sharding_degree * pp_degree * dp_degree, \
+            "global world size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
+                global_world_size, mp_degree, sharding_degree, pp_degree, dp_degree)
 
         # FIXME (JZ-LIANG) deprecated hybrid_dp
-        if self.user_defined_strategy.sharding_configs["hybrid_dp"]:
+        if sharding_configs["hybrid_dp"]:
             logger.warning(
-                "[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically"
-            )
-        assert self.dp_degree >= 1
-        if self.dp_degree > 1:
-            self.hybrid_dp = True
-        else:
-            self.hybrid_dp = False
-
-        # NOTE (JZ-LIANG)
-        # there 2 kind of modes for gradient-merge and hybrid-dp in mixed parallism [sharding] and [pipeline].
-        # we distinguish this two modes since the gm/hybrid-dp related allreduce should be insert in different place according different mode to have best performance:
-        # sharding: communication within node, and therefore should insert within backward segment to overlap with bw calc, conduct every micro step
-        # pipeline: communication accross nodes, and therefore should insert in update segemnt, conduct just once per global step
-        self.hybrid_dp_mode = None
+                "[hybrid_dp] API setting is deprecated. Now when "
+                "dp_degree >= 2, it will be in hybrid dp mode automatically")
+        assert dp_degree >= 1
+
+        self.hybrid_dp = True if dp_degree > 1 else False
+        self.sharding_degree = sharding_degree
+        self.mp_degree = mp_degree
+        self.pp_degree = pp_degree
+        self.dp_degree = dp_degree
+
+    def _get_hybrid_dp_mode(self):
+        """ get
+        self.hybrid_dp_mode
+        self.gradient_merge_mode
+        self._gradient_merge_acc_step
+        self.pp_allreduce_in_optimize
+        """
+        strategy = self.user_defined_strategy
+        sharding_configs = strategy.sharding_configs
+
+        # NOTE (JZ-LIANG)
+        # There are two kinds of modes for gradient-merge and hybrid-dp in mixed parallelism [sharding] and [pipeline].
+        # We distinguish these two modes since the gm/hybrid-dp related allreduce should be inserted in a
+        # different place according to the mode, to get the best performance:
+        #   sharding: communication within a node, and therefore should be inserted within the backward segment
+        #             to overlap with the bw calc, conducted every micro step.
+        #   pipeline: communication across nodes, and therefore should be inserted in the update segment,
+        #             conducted just once per global step.
        dp_mode = None
         # dp here is the pure dp as the outest parallelism
         if self.hybrid_dp:
-            assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
-                self.dp_degree)
             if self.pp_degree > 1:
-                self.hybrid_dp_mode = "pp_hybrid_dp"
+                dp_mode = "pp_hybrid_dp"
             else:
-                assert self.sharding_degree > 1, "by now we only support five kind of hybrid dp: sharding_hybrid_dp, mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
-                self.hybrid_dp_mode = "sharding_hybrid_dp"
+                assert self.sharding_degree > 1, \
+                    "by now we only support five kinds of hybrid dp: sharding_hybrid_dp, " \
+                    "mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
+                dp_mode = "sharding_hybrid_dp"
 
         # gradient merge
-        self._gradient_merge_acc_step = int(
-            self.user_defined_strategy.sharding_configs[
-                "gradient_merge_acc_step"])
-        self.gradient_merge_mode = None
+        gm_mode = None
+        gm_acc_step = int(sharding_configs["gradient_merge_acc_step"])
         if self.pp_degree <= 1:
-            self.gradient_merge_mode = "sharding_gm"
+            gm_mode = "sharding_gm"
             self._grad2merged_grad = dict()
         else:
-            self.gradient_merge_mode = "pp_gm"
-            self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[
-                'accumulate_steps']
-        if self._gradient_merge_acc_step > 1:
+            gm_mode = "pp_gm"
+            gm_acc_step = strategy.pipeline_configs['accumulate_steps']
+        if gm_acc_step > 1:
             logger.info("Gradient merge in [{}], acc step = [{}]".format(
-                self.gradient_merge_mode, self._gradient_merge_acc_step))
+                gm_mode, gm_acc_step))
 
-        # optimize offload
-        self.optimize_offload = self.user_defined_strategy.sharding_configs[
-            "optimize_offload"]
+        self.hybrid_dp_mode = dp_mode
+        self.gradient_merge_mode = gm_mode
+        self._gradient_merge_acc_step = gm_acc_step
 
         # this feature is design for ascend, and should NOT be used in GPU training
-        self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
+        self.pp_allreduce_in_optimize = sharding_configs[
             "pp_allreduce_in_optimize"]
 
+    def _inner_opt_minimize(self, loss, startup_program, parameter_list,
+                            no_grad_set):
+        pipeline_configs = self.user_defined_strategy.pipeline_configs
+
         if self.inner_opt is None:
             raise ValueError(
                 "self.inner_opt of ShardingOptimizer should not be None.")
@@ -195,32 +213,29 @@ class ShardingOptimizer(MetaOptimizerBase):
         if self.pp_degree > 1:
             pp_optimizer = fluid.optimizer.PipelineOptimizer(
                 self.inner_opt, self._gradient_merge_acc_step)
-
-            strategy = self.user_defined_strategy
-            self.schedule_mode = strategy.pipeline_configs['schedule_mode']
-            self.pp_rank_ = self.role_maker._worker_index() // (
-                self.sharding_degree * self.mp_degree) % self.pp_degree
-
-            pipeline_opt = dict()
-            pipeline_opt['schedule_mode'] = self.schedule_mode
-            pipeline_opt['micro_batch_size'] = strategy.pipeline_configs[
-                'micro_batch_size']
-            pipeline_opt['local_rank'] = self.pp_rank_
-            pipeline_opt['global_rank'] = self.role_maker._worker_index()
-            pipeline_opt['use_sharding'] = True
-            # TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
-            pipeline_opt['ring_id'] = 20
-            pipeline_opt['global_ring_id'] = 3
-            pipeline_opt['mp_degree'] = self.mp_degree
-            pipeline_opt['mp_rank'] = self.role_maker._worker_index(
-            ) % self.mp_degree
-
+            self._pp_optimizer = pp_optimizer
+
+            global_rank = self.role_maker._worker_index()
+            schedule_mode = pipeline_configs['schedule_mode']
+
+            pipeline_opt = {
+                'schedule_mode': schedule_mode,
+                'micro_batch_size': pipeline_configs['micro_batch_size'],
+                'local_rank': self.pp_rank,
+                'global_rank': global_rank,
+                'use_sharding': True,
+                # TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
+                'ring_id': 20,
+                'global_ring_id': 3,
+                'mp_degree': self.mp_degree,
+                'mp_rank': global_rank % self.mp_degree,
+            }
             main_program = loss.block.program
             main_program._pipeline_opt = pipeline_opt
 
             optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
                 loss, startup_program, parameter_list, no_grad_set)
-            self.pp_degree = len(program_list)
+            assert self.pp_degree == len(program_list)
         else:
             optimize_ops, params_grads = self.inner_opt.minimize(
                 loss, startup_program, parameter_list, no_grad_set)
@@ -230,9 +245,8 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         if self.pp_degree > 1:
             startup_program = startup_program._pipeline_opt['startup_program']
-            #main_program = main_program._pipeline_opt['section_program']['program']
-            print("pp_rank:", self.pp_rank_)
-            main_program = program_list[self.pp_rank_]
+            print("pp_rank:", self.pp_rank)
+            main_program = program_list[self.pp_rank]
             with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
                 f.writelines(str(main_program))
             main_block = main_program.global_block()
@@ -241,7 +255,6 @@ class ShardingOptimizer(MetaOptimizerBase):
                 if main_block.has_var(param.name):
                     new_params_grads.append((param, grad))
             params_grads = new_params_grads
-
         else:
             main_block = loss.block
@@ -254,93 +267,106 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
             f.writelines(str(main_program))
 
-        # step0: _init_comm
-        self._init_comm()
+        return optimize_ops, params_grads
 
-        if self.sharding_degree > 1:
+    def _apply_sharding_pass(self, params_grads):
+        if self.sharding_degree == 1: return
+
+        main_block = self._main_program.global_block()
+        startup_block = self._startup_program.global_block()
 
-            # step1: build shard
-            self._build_shard(params_grads)
+        # step1: build shard
+        self._build_shard(params_grads)
 
-            # step2: split_program
-            self._split_program(main_block)
+        # step2: split_program
+        self._split_program(main_block)
 
-            # step3: add broadcast and reduce ops
-            self._add_broadcast_allreduce(main_block)
-            main_block._sync_with_cpp()
-            startup_block._sync_with_cpp()
+        # step3: add broadcast and reduce ops
+        self._add_broadcast_allreduce(main_block)
+        main_block._sync_with_cpp()
+        startup_block._sync_with_cpp()
 
-        main_block._sync_with_cpp()
+        # step4: remove unneeded ops and vars from block
+        self._prune_main_program(main_block)
+        self._prune_startup_program(startup_block)
 
-        # step4: remove unneeded ops and vars from block
-        self._prune_main_program(main_block)
-        self._prune_startup_program(startup_block)
+    def _insert_allreduce_for_pp(self):
+        if self.pp_degree == 1: return
 
-        if self.pp_degree > 1:
-            # sharding-pp related logic
-            # pp_optimizer._rename_gradient_var_name(main_block)
-            # crop ops
-            if self.sharding_degree > 1:
-                for idx, op in reversed(list(enumerate(main_block.ops))):
-                    if is_update_op(op):
-                        op_role_var = op.attr('op_role_var')
-                        param_name = op_role_var[0]
-                        if not self._shard.has_param(param_name):
-                            main_block._remove_op(idx)
-
-                for idx, op in reversed(list(enumerate(main_block.ops))):
-                    if op.type != 'cast': continue
-                    in_name = op.input_arg_names[0]
-                    if in_name not in self._params: continue
-                    #if self._shard.has_param(param_name): continue
-                    if in_name not in main_block.vars:
+        strategy = self.user_defined_strategy
+        main_block = self._main_program.global_block()
+        startup_block = self._startup_program.global_block()
+
+        # sharding-pp related logic
+        # pp_optimizer._rename_gradient_var_name(main_block)
+        # crop ops
+        if self.sharding_degree > 1:
+            for idx, op in reversed(list(enumerate(main_block.ops))):
+                if is_update_op(op):
+                    op_role_var = op.attr('op_role_var')
+                    param_name = op_role_var[0]
+                    if not self._shard.has_param(param_name):
                         main_block._remove_op(idx)
 
-            accumulated_grad_names = pp_optimizer._accumulate_gradients(
-                main_block)
-            # accumulated_grad_names = sorted(accumulated_grad_names)
-            if self.pp_allreduce_in_optimize:
-                print("persistable FP32 grad: ")
-                print(accumulated_grad_names)
-                first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
-                    main_block, raise_error=self.user_defined_strategy.amp)
-                insert_reduce_ops(
+            for idx, op in reversed(list(enumerate(main_block.ops))):
+                if op.type != 'cast': continue
+                in_name = op.input_arg_names[0]
+                if in_name not in self._params: continue
+                #if self._shard.has_param(param_name): continue
+                if in_name not in main_block.vars:
+                    main_block._remove_op(idx)
+
+        accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
+            main_block)
+        # accumulated_grad_names = sorted(accumulated_grad_names)
+        if self.pp_allreduce_in_optimize:
+            print("persistable FP32 grad: ")
+            print(accumulated_grad_names)
+            first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
+                main_block, raise_error=strategy.amp)
+            insert_reduce_ops(
+                main_block,
+                first_optimize_op_index,
+                self.sharding_ring_id,
+                accumulated_grad_names,
+                self._shard,
+                core.op_proto_and_checker_maker.OpRole.Optimize,
+                use_calc_stream=True)
+        if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
+            first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
+                main_block, raise_error=strategy.amp)
+            if first_optimize_op_index >= 0:
+                insert_allreduce_ops(
                     main_block,
                     first_optimize_op_index,
-                    self.sharding_ring_id,
+                    self.dp_ring_id,
                     accumulated_grad_names,
-                    self._shard,
                     core.op_proto_and_checker_maker.OpRole.Optimize,
-                    use_calc_stream=True)
-            if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
-                first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
-                    main_block, raise_error=self.user_defined_strategy.amp)
-                if first_optimize_op_index >= 0:
-                    insert_allreduce_ops(
-                        main_block,
-                        first_optimize_op_index,
-                        self.dp_ring_id,
-                        accumulated_grad_names,
-                        core.op_proto_and_checker_maker.OpRole.Optimize,
-                        use_calc_stream=True,
-                        user_defined_strategy=self.user_defined_strategy)
+                    use_calc_stream=True,
+                    user_defined_strategy=strategy)
 
+    def _adapt_amp_clip_without_sharding(self):
+        if self.sharding_degree > 1: return
+
         # if not use sharding, adapt amp/clip, for remain parallelism.
         # cast --> amp --> clip --> opt
-        if self.sharding_degree <= 1:
-            # FIXME(wangxi): mp should prune duplicated param_grads when calc
-            # amp inf_var & clip global_norm_var
-            # amp
-            FP16Utils.sync_amp_check_nan_inf(
-                main_block, [self.mp_ring_id, self.pp_ring_id])
+        main_block = self._main_program.global_block()
+        startup_block = self._startup_program.global_block()
+
+        # FIXME(wangxi): mp should prune duplicated param_grads when calc
+        # amp inf_var & clip global_norm_var
+        FP16Utils.sync_amp_check_nan_inf(main_block,
+                                         [self.mp_ring_id, self.pp_ring_id])
 
-            # clip
-            gradientclip_helper = GradientClipHelper(None)
-            gradientclip_helper.sync_global_norm(
-                main_block, [self.mp_ring_id, self.pp_ring_id])
+        gradientclip_helper = GradientClipHelper(None)
+        gradientclip_helper.sync_global_norm(
+            main_block, [self.mp_ring_id, self.pp_ring_id])
+
+    def _insert_loss_grad_scale_op(self):
+        main_block = self._main_program.global_block()
 
-        # step6: loss div dp_degree
+        # step6: loss div dp_degree
         global_dp_degree = self.sharding_degree * self.dp_degree
         assert int(global_dp_degree) == global_dp_degree
         if global_dp_degree > 1:
@@ -348,18 +374,67 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         main_block._sync_with_cpp()
 
-        # TODO(wangxi): add optimize offload
-        # opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. >> 100)
-        # sync its memcpy could not be overlap with calc, otherwise it will slower down training severely.
-        if self.optimize_offload:
+    def _apply_optimize_offload_pass(self):
+        strategy = self.user_defined_strategy
+        sharding_configs = strategy.sharding_configs
+        main_block = self._main_program.global_block()
+        startup_block = self._startup_program.global_block()
+
+        # optimize offload should be enabled while gradient merge is enabled and
+        # acc_step is quite large (e.g. >> 100), since its memcpy cannot be
+        # overlapped with calc; otherwise it will slow down training severely.
+        if sharding_configs["optimize_offload"]:
             logger.info("Sharding with optimize offload !")
             offload_helper = OffloadHelper()
             offload_helper.offload(main_block, startup_block)
             offload_helper.offload_fp32param(main_block, startup_block)
 
+    def _dump_program_for_debug(self):
+        main_block = self._main_program.global_block()
+        startup_block = self._startup_program.global_block()
+        with open("start_sharding_%d" % self.role_maker._worker_index(),
+                  'w') as f:
+            f.writelines(str(startup_block.program))
+        with open("main_sharding_%d" % self.role_maker._worker_index(),
+                  'w') as f:
+            f.writelines(str(main_block.program))
+
+    def minimize_impl(self,
+                      loss,
+                      startup_program=None,
+                      parameter_list=None,
+                      no_grad_set=None):
+        # TODO: (JZ-LIANG) support multiple comm in future
+        # self._nrings = self.user_defined_strategy.nccl_comm_num
+        self._nrings_sharding = 1
+        self._nrings_dp = 1
+
+        self._get_sharding_segment_strategy()
+        self._get_hybrid_degree()
+        self._get_hybrid_dp_mode()
+
+        # config sharding & dp groups
+        self._build_groups()
+
+        # inner optimize minimize
+        optimize_ops, params_grads = self._inner_opt_minimize(
+            loss, startup_program, parameter_list, no_grad_set)
+
+        self._init_comm()
+
+        self._apply_sharding_pass(params_grads)
+
+        self._insert_allreduce_for_pp()
+
+        self._adapt_amp_clip_without_sharding()
+
+        # loss div dp_degree
+        self._insert_loss_grad_scale_op()
+
+        self._apply_optimize_offload_pass()
+
         # step6: (optional) sharding gradient merge
-        if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
-            self._sharding_gradient_merge(main_block)
+        self._sharding_gradient_merge()
 
         # # check op dependecy
         # FIXME (JZ-LIANG) enable checking in future.
@@ -367,17 +442,11 @@ class ShardingOptimizer(MetaOptimizerBase):
         #     check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
         #                         self.dp_ring_id)
 
-        if self.hybrid_dp:
-            # NOTE(JZ-LIANG) ensure in both sharding_hybrid_dp & pp_hybrid_dp
-            # init param broadcast should be called after startup pruning
-            self._initialization_broadcast(startup_block)
+        # NOTE(JZ-LIANG) ensure that in both sharding_hybrid_dp & pp_hybrid_dp,
+        # the init param broadcast is called after startup pruning
+        self._initialization_broadcast()
 
-        with open("start_sharding_%d" % self.role_maker._worker_index(),
-                  'w') as f:
-            f.writelines(str(startup_block.program))
-        with open("main_sharding_%d" % self.role_maker._worker_index(),
-                  'w') as f:
-            f.writelines(str(main_block.program))
+        self._dump_program_for_debug()
 
         # GPU need to wait server ready, GPU and NPU is Layered connection
         if not core.is_compiled_with_npu():
@@ -471,9 +540,6 @@ class ShardingOptimizer(MetaOptimizerBase):
 
     def _init_pipeline_comm(self, startup_block):
         # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
-        assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
-            self.pp_rank_, self.pp_rank)
-
         self._collective_helper._init_communicator(
             self._startup_program,
             self.current_endpoint,
@@ -496,17 +562,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             self._init_pair_comm(pair, ring_id)
 
     def _init_comm(self):
-
-        # config sharding & dp groups
-        self._build_groups()
-
         # sync var
         startup_block = self._startup_program.global_block()
-        self.startup_prog_sync_var = startup_block.create_var(
-            name="startup_prog_sync_var",
-            shape=[1],
-            dtype=core.VarDesc.VarType.INT32,
-            persistable=False)
 
         # mp ring
         if self.mp_degree > 1:
@@ -1051,7 +1108,8 @@ class ShardingOptimizer(MetaOptimizerBase):
         sharding: 1
         pure-dp: 2
         global: 3
-        pp: >= 20
+        pp: 4
+        pp-pair: >= 20
         if one parallelism is not enable: -1
         and only support parallelism hierarchy: mp --> sharding --> pp --> dp
         """
@@ -1216,11 +1274,16 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         return
 
-    def _initialization_broadcast(self, startup_block):
+    def _initialization_broadcast(self):
         """
         this funtion is to ensure the initialization between dp group to be
        identical when hybrid-dp is used.
         """
+        if not self.hybrid_dp:
+            return
+
+        startup_block = self._startup_program.global_block()
+
         params = []
         for param in startup_block.iter_parameters():
             params.append(param)
@@ -1461,13 +1524,17 @@ class ShardingOptimizer(MetaOptimizerBase):
         #     lr_var = main_block.var("gradient_merge_current_step")
         #     paddle.static.Print(lr_var, message="in OPTIMIZE last conditional")
 
-    def _sharding_gradient_merge(self, main_block):
+    def _sharding_gradient_merge(self):
         """
         copy all optimize ops in origin main block
         remove all optimize ops in origin main block
         create cond block
         """
+        if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
+            return
+
+        main_block = self._main_program.global_block()
         # copy original optimize ops to temp ops desc list
         # remove them from block 0
         tmp_copy_block = self._main_program._create_block()
-- 
GitLab
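
Note: the sharding_configs keys consumed by the refactored getters above (_get_sharding_segment_strategy, _get_hybrid_degree, _get_hybrid_dp_mode) are supplied by the user through fleet.DistributedStrategy. A minimal user-side sketch, not part of this patch; it assumes a static-graph collective launch with 4 workers so that the degree product matches the world size asserted in _get_hybrid_degree, and the concrete values are illustrative only:

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {
        # segment by broadcast volume; "segment_anchors" is the alternative
        "sharding_segment_strategy": "segment_broadcast_MB",
        "segment_broadcast_MB": 32,
        # hybrid parallelism degrees; _get_hybrid_degree asserts that
        # mp * sharding * pp * dp equals the global world size (here 4)
        "mp_degree": 1,
        "sharding_degree": 2,
        "pp_degree": 1,
        "dp_degree": 2,
        # with pp_degree <= 1 this selects the "sharding_gm" gradient merge mode
        "gradient_merge_acc_step": 4,
        "optimize_offload": False,
    }

    # wrap any paddle optimizer; minimize() then dispatches into
    # ShardingOptimizer.minimize_impl and the pass pipeline of this patch
    optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy)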
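
The ring-id table updated above (mp: 0, sharding: 1, pure-dp: 2, global: 3, pp: 4, pp-pair: >= 20) goes with the fixed parallelism hierarchy mp --> sharding --> pp --> dp. Under that ordering, a worker's coordinate in each group falls out of simple rank arithmetic. In the sketch below, the mp_rank and pp_rank lines match the expressions in _inner_opt_minimize and the removed pp_rank_ computation; the sharding_rank and dp_rank lines extend the same pattern and are an inference from the stated hierarchy, not code from this patch:

    def decompose_rank(rank, mp_degree, sharding_degree, pp_degree, dp_degree):
        # hierarchy mp --> sharding --> pp --> dp, with mp varying fastest
        mp_rank = rank % mp_degree
        sharding_rank = rank // mp_degree % sharding_degree
        pp_rank = rank // (mp_degree * sharding_degree) % pp_degree
        dp_rank = rank // (mp_degree * sharding_degree * pp_degree) % dp_degree
        return mp_rank, sharding_rank, pp_rank, dp_rank

    # 16 workers with mp=2, sharding=2, pp=2, dp=2 satisfies the world-size
    # assert; worker 13 then sits at (mp=1, sharding=0, pp=1, dp=1)
    assert decompose_rank(13, 2, 2, 2, 2) == (1, 0, 1, 1)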