diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
index 70753b59ccc318a25661e084bd305d7d76b0e2a6..9748bec3454d5368972baf42cbb8448869c8315c 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
@@ -126,6 +126,9 @@ class ProgramDeps(object):
 
     def should_remove_op(self, op_idx):
         op = self._block.ops[op_idx]
+        # remove check_finite_and_unscale op if its input 'X' is empty
+        if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
+            return True
         for output_name in op.desc.output_arg_names():
             if output_name not in self._should_removed_var:
                 return False
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
index ad1cd4f60826bbf434294114d1982cb4beb3f00a..1691bf7387ad18b303c6db798d1ab2873362b101 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -28,21 +28,24 @@ def check_broadcast(block):
     if the broadcasted var has a fill_constant op, the fill_constant
     op should stay forward before the broadcast op, and before a
     sync_calc op. Otherwise, raise error.
+
+    broadcast ops of inner parallelism (e.g. Megatron) are ignored and skipped
     """
     broadcast_vars = {}
     for idx, op in enumerate(block.ops):
         if op.type == "c_broadcast":
-            var_name = op.desc.input_arg_names()[0]
-            if "@BroadCast" in var_name:
-                if var_name in broadcast_vars:
-                    raise ValueError("var_name areadly exist: {}"
-                                     "the old pos is {}, the new pos is {}".
-                                     format(var_name, broadcast_vars[var_name][
-                                         "broadcast_pos"], idx))
-                broadcast_vars[var_name] = {
-                    "fill_constant_pos": -1,
-                    "broadcast_pos": idx,
-                }
+            if op.all_attrs()["use_calc_stream"] == False:
+                var_name = op.desc.input_arg_names()[0]
+                if "@BroadCast" in var_name:
+                    if var_name in broadcast_vars:
+                        raise ValueError("var_name already exists: {}"
+                                         "the old pos is {}, the new pos is {}".
+ format(var_name, broadcast_vars[ + var_name]["broadcast_pos"], idx)) + broadcast_vars[var_name] = { + "fill_constant_pos": -1, + "broadcast_pos": idx, + } for idx, op in enumerate(block.ops): if op.type == "fill_constant": @@ -61,14 +64,15 @@ def check_broadcast(block): last_sync_calc_op_idx = idx continue if op.type == "c_broadcast": - var_name = op.desc.input_arg_names()[0] - if "@BroadCast" in var_name: - if broadcast_vars[var_name]["fill_constant_pos"] != -1: - assert (last_sync_calc_op_idx != -1) - assert (broadcast_vars[var_name]["fill_constant_pos"] < - last_sync_calc_op_idx) - assert (last_sync_calc_op_idx < idx) - continue + if op.all_attrs()["use_calc_stream"] == False: + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if broadcast_vars[var_name]["fill_constant_pos"] != -1: + assert (last_sync_calc_op_idx != -1) + assert (broadcast_vars[var_name]["fill_constant_pos"] < + last_sync_calc_op_idx) + assert (last_sync_calc_op_idx < idx) + continue for input_name in op.desc.input_arg_names(): if input_name in broadcast_vars: assert (broadcast_vars[input_name]["broadcast_pos"] != -1) @@ -78,7 +82,7 @@ def check_broadcast(block): return -def check_allreduce_sum(block, shard, dp_ring_id=-1): +def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): """ the op order should be: grad: @@ -89,32 +93,36 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): - 4: allreuce_sum_dp (dp_grads) - 5: sync_comm (dp_grads) - 6: op that use Var (dp_grads & sum) + + should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron) """ vars_status = {} dp_grads_status = {} idx_last_grad_allreduce = -1 idx_amp_allreduce = -1 idx_gradient_clip_allreduce = -1 + for idx, op in enumerate(block.ops): if op.type == "c_allreduce_sum": - ring_id = op.desc.attr("ring_id") - var_name = op.desc.input_arg_names()[0] - param = var_name.split("@")[0] + if op.all_attrs()["use_calc_stream"] == False: + ring_id = op.desc.attr("ring_id") + var_name = op.desc.input_arg_names()[0] + param = var_name.split("@")[0] - assert 'sum' in var_name or ("@GRAD" in var_name) - if 'sum' in var_name or (not shard.has_param(param)): - vars_status[var_name] = -1 - else: - dp_grads_status[var_name] = -1 + assert 'sum' in var_name or ("@GRAD" in var_name) + if 'sum' in var_name or (not shard.has_param(param)): + vars_status[var_name] = -1 + else: + dp_grads_status[var_name] = -1 - if ring_id != 0: - assert shard.has_param(param) - assert ring_id == dp_ring_id + if ring_id != sharding_ring_id: + assert shard.has_param(param) + assert ring_id == dp_ring_id - if "sum" in var_name: - idx_amp_allreduce = idx - elif "@GRAD": - idx_last_grad_allreduce = idx + if "sum" in var_name: + idx_amp_allreduce = idx + elif "@GRAD": + idx_last_grad_allreduce = idx if op.type == "c_allreduce_max": idx_gradient_clip_allreduce = idx @@ -130,36 +138,38 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): dp_grads_status[var_name] = 1 elif op.type == "c_allreduce_sum": - var_name = op.desc.input_arg_names()[0] - ring_id = op.desc.attr("ring_id") - if ring_id == 0: - if var_name in vars_status: - _status = vars_status[var_name] - else: - _status = dp_grads_status[var_name] - if _status == -1: - raise ValueError("{} is not generated, but you are" - "trying to all-reduce it".format(var_name)) - if _status == 0: - raise ValueError("There should be a sync_calc op " - "after generate Var: {} and before the" - "c_allreduce_sum op".format(var_name)) - assert (_status == 1) - if var_name in vars_status: - 
vars_status[var_name] = 2 + if op.all_attrs()["use_calc_stream"] == False: + var_name = op.desc.input_arg_names()[0] + ring_id = op.desc.attr("ring_id") + if ring_id == sharding_ring_id: + if var_name in vars_status: + _status = vars_status[var_name] + else: + _status = dp_grads_status[var_name] + if _status == -1: + raise ValueError("{} is not generated, but you are" + "trying to all-reduce it".format( + var_name)) + if _status == 0: + raise ValueError("There should be a sync_calc op " + "after generate Var: {} and before the" + "c_allreduce_sum op".format(var_name)) + assert (_status == 1) + if var_name in vars_status: + vars_status[var_name] = 2 + else: + dp_grads_status[var_name] = 2 else: - dp_grads_status[var_name] = 2 - else: - assert ring_id == dp_ring_id - param = var_name.split("@")[0] - assert shard.has_param(param) - assert dp_grads_status[var_name] == 3 - dp_grads_status[var_name] = 4 + assert ring_id == dp_ring_id + param = var_name.split("@")[0] + assert shard.has_param(param) + assert dp_grads_status[var_name] == 3 + dp_grads_status[var_name] = 4 elif op.type == "c_sync_comm_stream": var_name = op.desc.input_arg_names()[0] ring_id = op.desc.attr("ring_id") - if ring_id == 0: + if ring_id == sharding_ring_id: for var_name in op.desc.input_arg_names(): if var_name in vars_status: assert vars_status[var_name] == 2 @@ -217,9 +227,14 @@ def get_valid_op_role(block, insert_idx): return OpRole.Forward or OpRole.Backward """ op_role = block.ops[insert_idx].attr('op_role') - if (insert_idx >= len(block.ops)) or ( - op_role in [int(OpRole.Backward), int(OpRole.Optimize)]): - return OpRole.Backward + #if (insert_idx >= len(block.ops)) or ( + # op_role in [int(OpRole.Backward), int(OpRole.Optimize)]): + # return OpRole.Backward + #if op_role in [int(OpRole.Forward), int(OpRole.Loss)]: + # return OpRole.Forward + if insert_idx >= len(block.ops): return OpRole.Optimize + if op_role == int(OpRole.Backward): return OpRole.Backward + if op_role == int(OpRole.Optimize): return OpRole.Optimize if op_role in [int(OpRole.Forward), int(OpRole.Loss)]: return OpRole.Forward @@ -428,7 +443,7 @@ def comm_analyse(main_program): count)) -def add_sync_comm(program, dist_strategy): +def add_sync_comm(program, nccl_ids): """ When clone a test prog by clone from the sharding main prog, part of the sync_comm op maybe be pruned by mistake, this function @@ -438,6 +453,9 @@ def add_sync_comm(program, dist_strategy): #NOTE (liangjianzhong): only support one comm stream by now, use more than one # comm streams will cause error. should be revise in future. + assert isinstance( + nccl_ids, list + ), "the second argument of this function should be a list of nccl_ids" block = program.global_block() not_sync_vars = set([]) for op in block.ops: @@ -448,7 +466,7 @@ def add_sync_comm(program, dist_strategy): for input_name in op.desc.input_arg_names(): not_sync_vars.remove(input_name) if not_sync_vars: - for nccl_id in range(dist_strategy.nccl_comm_num): + for nccl_id in nccl_ids: block.append_op( type='c_sync_comm_stream', inputs={'X': list(not_sync_vars)}, @@ -467,6 +485,9 @@ def save_persistables(exe, dirname, main_program, filename=None): This function handles the model saving for sharding training. 
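+
+    If pipeline parallelism is enabled (main_program carries a _pipeline_opt),
+    the pipeline section program is saved instead of the passed main_program.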
""" + if main_program._pipeline_opt: + main_program = main_program._pipeline_opt['section_program']['program'] + def is_opt_vars(var): # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer # now only Momentum and adam are compatible with sharding diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 7fed227cd9936ccac730b58959a2ce8bed51e4ef..35bfd6a1b0c20f0075caa819f84e85bc02de6021 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core import paddle.fluid as fluid from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper -from paddle.distributed.fleet.meta_optimizers.common import is_backward_op +from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils @@ -39,6 +39,7 @@ class ShardingOptimizer(MetaOptimizerBase): "AMPOptimizer", "LarsOptimizer", "LambOptimizer", + "ModelParallelOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self._main_program = None @@ -51,6 +52,10 @@ class ShardingOptimizer(MetaOptimizerBase): self._reduced_grads_to_param = {} self._shard = Shard() + # use sharding as outer parallelism (e.g. inner:Megatron & outer sharding) + self._as_outer_parallelism = False + self._inner_parallelism_size = None + def _can_apply(self): if not self.role_maker._is_collective: return False @@ -79,20 +84,61 @@ class ShardingOptimizer(MetaOptimizerBase): "fuse_broadcast_MB"] self.hybrid_dp = self.user_defined_strategy.sharding_configs[ "hybrid_dp"] + self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[ + "as_outer_parallelism"] + self._inner_parallelism_size = int( + self.user_defined_strategy.sharding_configs[ + "inner_parallelism_size"]) + self.use_pipeline = self.user_defined_strategy.sharding_configs[ + "use_pipeline"] if self.inner_opt is None: raise ValueError( "self.inner_opt of ShardingOptimizer should not be None.") - optimize_ops, params_grads = self.inner_opt.minimize( - loss, startup_program, parameter_list, no_grad_set) + if self.use_pipeline: + pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt) + main_program = loss.block.program + main_program._pipeline_opt = dict() + pp_rank = self.role_maker._worker_index( + ) // self.user_defined_strategy.sharding_configs[ + 'sharding_group_size'] + main_program._pipeline_opt['local_rank'] = pp_rank + main_program._pipeline_opt[ + 'global_rank'] = self.role_maker._worker_index() + main_program._pipeline_opt['use_sharding'] = True + main_program._pipeline_opt['ring_id'] = 1 + optimize_ops, params_grads, program_list = pp_optimizer.minimize( + loss, startup_program, parameter_list, no_grad_set) + self.pipeline_nodes = len(program_list) + else: + optimize_ops, params_grads = self.inner_opt.minimize( + loss, startup_program, parameter_list, no_grad_set) if startup_program is None: startup_program = default_startup_program() - main_block = loss.block + if self.use_pipeline: + startup_program = 
startup_program._pipeline_opt['startup_program'] + #main_program = main_program._pipeline_opt['section_program']['program'] + print("pp_rank:", pp_rank) + main_program = program_list[pp_rank]['program'] + with open("main_%d" % self.role_maker._worker_index(), 'w') as f: + f.writelines(str(main_program)) + main_block = main_program.global_block() + new_params_grads = [] + for param, grad in params_grads: + if main_block.has_var(param.name): + new_params_grads.append((param, grad)) + params_grads = new_params_grads + + else: + main_block = loss.block startup_block = startup_program.global_block() self._main_program = main_block.program self._startup_program = startup_program + if self.use_pipeline: + pp_optimizer._rename_gradient_var_name(main_block) + # step1: set_up self._set_up(params_grads) @@ -105,17 +151,200 @@ class ShardingOptimizer(MetaOptimizerBase): startup_block._sync_with_cpp() # step4: insert reduce_sum for grad - insert_scale_loss_grad_ops( - main_block, scale=1.0 / self.role_maker._worker_num()) + # grad_scale_coeff = self.role_maker._worker_num() + # if self._as_outer_parallelism: + # grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size + # insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff) + sharding_group_size = self.user_defined_strategy.sharding_configs[ + 'sharding_group_size'] + insert_scale_loss_grad_ops(main_block, scale=1.0 / sharding_group_size) main_block._sync_with_cpp() # step5: remove unneeded ops and vars from block self._prune_main_program(main_block) self._prune_startup_program(startup_block) + if self.hybrid_dp: + self._initialization_broadcast(startup_program) + + if self.use_pipeline: + # crop ops + for idx, op in reversed(list(enumerate(main_block.ops))): + # if op.type == 'fill_constant' and int(op.attr('op_role')) == 16: + # out_name = op.output_arg_names[0] + # if not 'GRAD' in out_name: continue + # param_name = out_name.strip("@GRAD") + # #if main_block.has_var(out_name): continue + # if self._shard.has_param(param_name): continue + # main_block._remove_op(idx) + if is_update_op(op): + op_role_var = op.attr('op_role_var') + param_name = op_role_var[0] + if not self._shard.has_param(param_name): + main_block._remove_op(idx) + + param_list = [] + for param_name, grad_name in params_grads: + if self._shard.has_param(param_name): + param_list.append(param_name) + #pp_optimizer._clear_gradients(main_block, param_list) + pp_optimizer._accumulate_gradients(main_block) + #if not self._shard.has_param(param_name): continue + ##if not main_block.has_var(grad_name): continue + #assert main_block.has_var(grad_name) + #grad_var = main_block.vars[grad_name] + #grad_var.persistable = True + #main_block._insert_op( + # index=0, + # type='fill_constant', + # inputs={}, + # outputs={'Out': [grad_var]}, + # attrs={ + # 'shape': grad_var.shape, + # 'dtype': grad_var.dtype, + # 'value': float(0), + # #self._op_device_key: device, + # # a trick to run this op once per mini-batch + # 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched, + # }) + + #def _create_var(block, ref_var, name): + # """ + # Create a new var for block, which has the same type, + # shape and dtype as ref_var, then rename it with the + # name `name`. 
+ # """ + # new_var = block.create_var( + # name=name, + # shape=ref_var.shape, + # dtype=ref_var.dtype, + # type=ref_var.type, + # lod_level=ref_var.lod_level, + # persistable=ref_var.persistable, + # is_data=ref_var.is_data, + # need_check_feed=ref_var.desc.need_check_feed()) + # new_var.stop_gradient = ref_var.stop_gradient + # return new_var + + #def _rename_arg(op, old_name, new_name): + # op_desc = op.desc + # if isinstance(op_desc, tuple): + # op_desc = op_desc[0] + # op_desc._rename_input(old_name, new_name) + # op_desc._rename_output(old_name, new_name) + + #print("params_grads:", params_grads) + #for param_name, grad_name in params_grads: + # if not self._shard.has_param(param_name): continue + # #if not main_block.has_var(grad_name): continue + # assert main_block.has_var(grad_name) + # use_fp16 = False + # fp16_grad_name = param_name + '.cast_fp16@GRAD' + # if main_block.has_var(grad_name): + # fp16_grad_var = main_block.vars[fp16_grad_name] + # use_fp16 = True + # grad_var = main_block.vars[grad_name] + # if use_fp16: + # cast_grad_var_name = paddle.fluid.unique_name.generate( + # grad_name) + # cast_var = _create_var(main_block, fp16_grad_var, + # cast_grad_var_name) + # cast_var.persistable = False + # main_block.append_op( + # #index=offset + 1, + # type='cast', + # inputs={'X': grad_var}, + # outputs={'Out': cast_var}, + # attrs={ + # 'in_dtype': grad_var.dtype, + # 'out_dtype': cast_var.dtype, + # 'op_role': + # core.op_proto_and_checker_maker.OpRole.Backward, + # }) + # #offset += 1 + # main_block.append_op( + # #index=offset + 1, + # type='sum', + # inputs={'X': [fp16_grad_var, cast_var]}, + # outputs={'Out': fp16_grad_var}, + # attrs={ + # 'op_role': + # core.op_proto_and_checker_maker.OpRole.Backward, + # 'op_role_var': op_role_var + # }) + + # for index, op in reversed(tuple(enumerate(list(main_block.ops)))): + # offset = index + # if is_backward_op(op) and ( + # 'op_role_var' in op.attr_names): + # op_role_var = op.all_attrs()['op_role_var'] + + # if len(op_role_var) == 0: + # continue + # assert len(op_role_var) % 2 == 0 + # offset = index + # for i in range(0, len(op_role_var), 2): + # grad_name = op_role_var[i + 1] + # if not main_block.has_var(grad_name): continue + # grad_var = main_block.vars[grad_name] + # if not 'cast_fp16' in grad_name: + # new_grad_var_name = paddle.fluid.unique_name.generate(grad_name) + # new_var = _create_var(main_block, grad_var, + # new_grad_var_name) + # new_var.persistable = False + # _rename_arg(op, grad_name, new_grad_var_name) + # main_block._insert_op( + # index=offset + 1, + # type='sum', + # inputs={'X': [grad_var, new_var]}, + # outputs={'Out': grad_var}, + # attrs={ + # 'op_role': core.op_proto_and_checker_maker.OpRole.Backward, + # 'op_role_var': op_role_var + # }) + # offset += 1 + # if 'cast_fp16' in grad_name: + # param_name = op_role_var[i] + # fp32_grad_var_name = param_name + "@GRAD" + # fp32_grad_var = main_block.vars[grad_name] + # cast_grad_var_name = paddle.fluid.unique_name.generate( + # fp32_grad_var_name) + # cast_var = _create_var(main_block, grad_var, + # cast_grad_var_name) + # cast_var.persistable = False + # main_block._insert_op( + # index=offset + 1, + # type='cast', + # inputs={'X': fp32_grad_var}, + # outputs={'Out': cast_var}, + # attrs={ + # 'in_dtype': fp32_grad_var.dtype, + # 'out_dtype': cast_var.dtype, + # 'op_role': core.op_proto_and_checker_maker.OpRole.Backward, + # # self._op_role_var_key: op_role_var + # }) + # offset += 1 + # main_block._insert_op( + # index=offset + 1, + # type='sum', + # 
inputs={'X': [grad_var, cast_var]}, + # outputs={'Out': grad_var}, + # attrs={ + # 'op_role': core.op_proto_and_checker_maker.OpRole.Backward, + # 'op_role_var': op_role_var}) + main_block._sync_with_cpp() + + with open("start_sharding_%d" % self.role_maker._worker_index(), + 'w') as f: + f.writelines(str(startup_block.program)) + with open("main_sharding_%d" % self.role_maker._worker_index(), + 'w') as f: + f.writelines(str(main_block.program)) # check op dependecy check_broadcast(main_block) - check_allreduce_sum(main_block, self._shard, self.dp_ring_id) + check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, + self.dp_ring_id) + #check_allreduce_sum(main_block, self._shard, self.dp_ring_id) self._wait() return optimize_ops, params_grads @@ -134,11 +363,23 @@ class ShardingOptimizer(MetaOptimizerBase): self._startup_program, self.current_endpoint, self.sharding_group_endpoints, self.sharding_rank, self.sharding_ring_id, True) + + # inner & outer model parallelism + if self._as_outer_parallelism: + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + self.mp_group_endpoints, self.mp_rank, self.mp_group_id, True) + # dp if self.hybrid_dp: self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True) + # pp + if self.use_pipeline: + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, True) startup_block = self._startup_program.global_block() startup_block._sync_with_cpp() @@ -205,8 +446,8 @@ class ShardingOptimizer(MetaOptimizerBase): for i in range(0, len(op_role_var), 2): param, reduced_grad = op_role_var[i], op_role_var[i + 1] segment._allreduce_vars.append(reduced_grad) - assert ( - reduced_grad not in self._reduced_grads_to_param) + #assert ( + # reduced_grad not in self._reduced_grads_to_param) self._reduced_grads_to_param[reduced_grad] = param # find cast op @@ -234,9 +475,14 @@ class ShardingOptimizer(MetaOptimizerBase): """ weightdecay_helper = WeightDecayHelper() weightdecay_helper.prune_weight_decay(block, self._shard) + # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism + # group. 
and each Data Parallelism group should have its own sync of FoundInfinite + Model_Paramllelism_ring_id = self.sharding_ring_id + if self._as_outer_parallelism: + Model_Paramllelism_ring_id = self.mp_group_id FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, - self.sharding_ring_id) - gradientclip_helper = GradientClipHelper(self.sharding_ring_id) + Model_Paramllelism_ring_id) + gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id) gradientclip_helper.prune_gradient_clip(block, self._shard) # build prog deps @@ -264,8 +510,13 @@ class ShardingOptimizer(MetaOptimizerBase): # Prune for idx, op in reversed(list(enumerate(block.ops))): if op.type in [ - "c_allreduce_sum", "c_sync_comm_stream", - "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init, c_comm_init_hcom" + "c_allreduce_sum", + "c_sync_comm_stream", + "c_calc_comm_stream", + "c_gen_nccl_id", + "c_comm_init", + 'send_v2', + 'recv_v2', ]: pass elif op.type == "conditional_block": @@ -303,6 +554,14 @@ class ShardingOptimizer(MetaOptimizerBase): program_deps.remove_op(idx) block._sync_with_cpp() + for idx, op in reversed(list(enumerate(block.ops))): + if op.type == 'concat' and is_optimizer_op(op): + # remove inputs that not on this card + reserved_x = [] + for var_name in op.desc.input("X"): + if block.has_var(var_name): reserved_x.append(var_name) + op.desc.set_input('X', reserved_x) + block._sync_with_cpp() return def _add_broadcast_allreduce(self, block): @@ -459,6 +718,7 @@ class ShardingOptimizer(MetaOptimizerBase): def _init_comm(self): if self.hybrid_dp: + assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism" self.sharding_group_size = self.user_defined_strategy.sharding_configs[ "sharding_group_size"] self.sharding_ring_id = 0 @@ -485,13 +745,109 @@ class ShardingOptimizer(MetaOptimizerBase): self.global_word_size, self.sharding_group_size, self.dp_group_size) + self.pp_ring_id = -1 + self.pp_rank = -1 + self.pp_group_size = None + self.pp_group_endpoints = None + + # sharding parallelism is the only model parallelism in the current setting + self.mp_group_id = self.sharding_ring_id + self.mp_rank = self.sharding_rank + self.mp_group_size = self.sharding_group_size + self.mp_group_endpoints = self.sharding_group_endpoints[:] logging.info("Using Sharing&DP mode !") else: - self.sharding_ring_id = 0 - self.sharding_rank = self.global_rank - self.sharding_group_size = self.role_maker._worker_num() - self.sharding_group_endpoints = self.endpoints + if self._as_outer_parallelism: + self.sharding_ring_id = 1 + assert self.global_word_size > self._inner_parallelism_size, \ + "global_word_size: {} should be larger than inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size) + assert self.global_word_size % self._inner_parallelism_size == 0, \ + "global_word_size: {} should be divisible to the inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size) + self.sharding_rank = self.global_rank // self._inner_parallelism_size + self.sharding_group_size = self.role_maker._worker_num( + ) // self._inner_parallelism_size + _offset = self.global_rank % self._inner_parallelism_size + self.sharding_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if idx % self._inner_parallelism_size == _offset + ] + + # the current entire model parallelism group is the combination of innert & sharding parallelism + self.mp_group_id = 2 + self.mp_rank = self.global_rank + self.mp_group_size = 
self.role_maker._worker_num() + self.mp_group_endpoints = self.endpoints[:] + logging.info("Using Sharing as Outer parallelism mode !") + + # print( + # "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer" + # ) + # partition_idx = self.global_rank // self._inner_parallelism_size + # magetron_endpoints = self.endpoints[ + # partition_idx * self._inner_parallelism_size:partition_idx * + # self._inner_parallelism_size + self._inner_parallelism_size] + # magetron_rank = self.global_rank % self._inner_parallelism_size + + # self._collective_helper._init_communicator( + # program=self._startup_program, + # current_endpoint=self.current_endpoint, + # endpoints=magetron_endpoints, + # rank=magetron_rank, + # ring_id=0, + # wait_port=True) + # logging.info("megatron group size: {}".format( + # self._inner_parallelism_size)) + # logging.info("megatron rank: {}".format(magetron_rank)) + # logging.info("megatron endpoints: {}".format( + # magetron_endpoints)) + if self.use_pipeline: + self.sharding_ring_id = 0 + self.sharding_group_size = self.user_defined_strategy.sharding_configs[ + 'sharding_group_size'] + self.sharding_rank = self.global_rank % self.sharding_group_size + assert self.sharding_group_size * self.pipeline_nodes == self.role_maker._worker_num( + ) + self.pp_ring_id = 1 + self.pp_rank = self.global_rank // self.sharding_group_size + self.sharding_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if (idx // self.sharding_group_size) == self.pp_rank + ] + self.pp_group_size = self.pipeline_nodes + self.pp_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if (idx % self.sharding_group_size) == self.sharding_rank + ] + self.dp_ring_id = -1 + self.dp_rank = -1 + self.dp_group_size = None + self.dp_group_endpoints = None + + logging.info("Using Sharing with pipeline !") + else: + self.sharding_ring_id = 0 + self.sharding_rank = self.global_rank + self.sharding_group_size = self.role_maker._worker_num() + self.sharding_group_endpoints = self.endpoints + + # sharding parallelism is the only model parallelism in the current setting + self.mp_group_id = self.sharding_ring_id + self.mp_rank = self.sharding_rank + self.mp_group_size = self.sharding_group_size + self.mp_group_endpoints = self.sharding_group_endpoints[:] + + logging.info("Using Sharing alone mode !") + + self.dp_ring_id = -1 + self.dp_rank = -1 + self.dp_group_size = None + self.dp_group_endpoints = None + + self.pp_ring_id = -1 + self.pp_rank = -1 + self.pp_group_size = None + self.pp_group_endpoints = None self.dp_ring_id = -1 self.dp_rank = -1 self.dp_group_size = None @@ -503,12 +859,42 @@ class ShardingOptimizer(MetaOptimizerBase): logging.info("global rank: {}".format(self.global_rank)) logging.info("sharding group_size: {}".format(self.sharding_group_size)) logging.info("sharding rank: {}".format(self.sharding_rank)) + logging.info("current model parallelism group_size: {}".format( + self.mp_group_size)) + logging.info("current model parallelism rank: {}".format(self.mp_rank)) logging.info("dp group size: {}".format(self.dp_group_size)) logging.info("dp rank: {}".format(self.dp_rank)) logging.info("current endpoint: {}".format(self.current_endpoint)) + logging.info("global word endpoints: {}".format(self.endpoints)) logging.info("sharding group endpoints: {}".format( self.sharding_group_endpoints)) + logging.info("current model parallelism group endpoints: {}".format( + self.mp_group_endpoints)) logging.info("dp group endpoints: 
{}".format(self.dp_group_endpoints)) - logging.info("global word endpoints: {}".format(self.endpoints)) return + + def _initialization_broadcast(self, startup_prog): + """ + this funtion is to ensure the initialization between dp group to be + identical when hybrid-dp is used. + """ + block = startup_prog.global_block() + params = [] + for param in block.iter_parameters(): + params.append(param) + block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': self.dp_ring_id, + 'root': 0, + OP_ROLE_KEY: OpRole.Forward + }) + block.append_op( + type='c_sync_comm_stream', + inputs={'X': params}, + outputs={'Out': params}, + attrs={'ring_id': self.dp_ring_id, + OP_ROLE_KEY: OpRole.Forward}) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 33e2e387a82758ba9cd59dc40d41fb5ad05ee29b..abf02851b9c2dcc7b33ed19b2f2a16f1b693313e 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -115,7 +115,7 @@ class ProgramStats(object): updated_min_idx = min_idx while idx_ > pre_segment_end_idx: if is_amp_cast(self.ops[idx_]): - _logger.debug("found amp-cast op: {}, : {}".format(self.ops[ + _logger.info("found amp-cast op: {}, : {}".format(self.ops[ idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[ 0])) updated_min_idx = idx_ @@ -155,7 +155,7 @@ class ProgramStats(object): sorted_checkpoints = [] for name in checkpoints_name: if name not in self.var_op_deps: - _logger.debug( + _logger.info( "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." % name) elif self.var_op_deps[name]["var_as_output_ops"] == []: @@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars): new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) new_op_desc._set_attr(op_role_attr_name, backward) + if desc.has_attr('op_device'): + new_op_desc._set_attr('op_device', desc.attr('op_device')) result_descs.append(new_op_desc) return result_descs @@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block): new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) new_op_desc._set_attr(op_role_attr_name, backward) + if desc.has_attr('op_device'): + new_op_desc._set_attr('op_device', desc.attr('op_device')) result_descs.append(new_op_desc) return result_descs @@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_( start_idx = 0 pre_segment_end_idx = -1 while True: - _logger.debug("FW op range[0] - [{}]".format(len(ops))) if start_idx >= len(checkpoints_name) - 1: break # min_idx: checkpoint_1' s input op @@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_( min_idx = program_stat._update_segment_start( min_idx, pre_segment_end_idx) segments.append([min_idx, max_idx + 1]) + else: + _logger.info("Could not recompute op range [{}] - [{}] ".format( + min_idx, max_idx + 1)) start_idx += 1 @@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_( recompute_segments = segments for i, (idx1, idx2) in enumerate(recompute_segments): - _logger.debug("recompute segment[{}]".format(i)) - _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( + _logger.info("recompute segment[{}]".format(i)) + _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( ), ops[idx1].desc.input_arg_names())) - _logger.debug("segment end op: [{}]: [{}]".format(ops[ + _logger.info("segment end op: [{}]: [{}]".format(ops[ idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) - _logger.debug("recompute 
segment[{}]".format(i)) - _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( + _logger.info("recompute segment[{}]".format(i)) + _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( ), ops[idx1].desc.input_arg_names())) - _logger.debug("segment end op: [{}]: [{}]".format(ops[ + _logger.info("segment end op: [{}]: [{}]".format(ops[ idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) # 2) go through all forward ops and induct all variables that will be hold in memory @@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_( program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) cross_vars = set(vars_should_be_hold) - set(checkpoints_name) - _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ + _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ len(cross_vars), cross_vars)) - _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ + _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ len(cross_vars), cross_vars)) # b. output of seed op should be kept in memory @@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_( vars_in_memory = vars_should_be_hold + checkpoints_name max_calculated_op_position = len(ops) + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() if recompute_segments == []: gap_ops = ops[0:max_calculated_op_position] for op in reversed(gap_ops): @@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_( _pretty_op_desc_(op.desc, "with_sub_block")) grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, cpt.to_text(no_grad_dict[block.idx]), []) + # Set device for grad_op according to forward Op + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) added_descs = _add_descs_to_block(grad_op_desc, local_block) grad_op_descs.extend(added_descs) grad_to_var.update(op_grad_to_var) @@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_( _pretty_op_desc_(op.desc, "with_sub_block")) grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, cpt.to_text(no_grad_dict[block.idx]), []) + # Set device for grad_op according to forward Op + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) added_descs = _add_descs_to_block(grad_op_desc, local_block) grad_op_descs.extend(added_descs) grad_to_var.update(op_grad_to_var) @@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_( continue if name not in var_name_dict: var_name_dict[name] = name + var_suffix + + # we should create the rename var in subprog, otherwise its VarType will be BOOL + block.create_var( + name=var_name_dict[name], + shape=block.program.global_block().var(name).shape, + dtype=block.program.global_block().var(name).dtype, + type=block.program.global_block().var(name).type, + persistable=block.program.global_block().var( + name).persistable, + stop_gradient=block.program.global_block().var(name) + .stop_gradient) + # 3.a. add ops in current recompute_segment as forward recomputation ops buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block, vars_in_memory)