From e5b4dd7386486610a183460e88e21b8899bd1d55 Mon Sep 17 00:00:00 2001 From: danleifeng <52735331+danleifeng@users.noreply.github.com> Date: Mon, 11 Oct 2021 20:47:08 +0800 Subject: [PATCH] [heterps] add fuse_allreduce (#35131) * heterps:add fuse_allreduce op; test=develop * add program_mode in minimize for pslib mode;test=develop --- python/paddle/distributed/fleet/utils/fs.py | 13 +- .../fleet/parameter_server/pslib/__init__.py | 13 +- python/paddle/fluid/transpiler/collective.py | 267 +++++++++++++++++- 3 files changed, 284 insertions(+), 9 deletions(-) diff --git a/python/paddle/distributed/fleet/utils/fs.py b/python/paddle/distributed/fleet/utils/fs.py index d3f84d50ac8..f56580f8ca2 100644 --- a/python/paddle/distributed/fleet/utils/fs.py +++ b/python/paddle/distributed/fleet/utils/fs.py @@ -468,10 +468,17 @@ class HDFSClient(FS): self._bd_err_re = re.compile( r'\s?responseErrorMsg\s?\:.*, errorCode\:\s?[0-9]+, path\:') - def _run_cmd(self, cmd, redirect_stderr=False): + def _run_cmd(self, cmd, redirect_stderr=False, retry_times=5): exe_cmd = "{} -{}".format(self._base_cmd, cmd) - ret, output = core.shell_execute_cmd(exe_cmd, 0, 0, redirect_stderr) - ret = int(ret) + ret = 0 + output = None + retry_sleep_second = 3 + for x in range(retry_times + 1): + ret, output = core.shell_execute_cmd(exe_cmd, 0, 0, redirect_stderr) + ret = int(ret) + if ret == 0: + break + time.sleep(retry_sleep_second) if ret == 134: raise FSShellCmdAborted(cmd) return ret, output.splitlines() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index d245ce222ca..78af7fd65dc 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -1091,7 +1091,8 @@ class DownpourOptimizer(DistributedOptimizer): scopes=None, startup_programs=None, parameter_list=None, - no_grad_set=None): + no_grad_set=None, + program_mode="all_reduce"): """ minimize a program through loss, loss can be a list in DistributedOptimizer. Note that in parameter server mode, a worker will not get anything about optimize_os @@ -1105,6 +1106,7 @@ class DownpourOptimizer(DistributedOptimizer): in `parameter_list`. parameter_list (list): list of Variables to update. no_grad_set (set|None): set of Variables should be ignored. + program_mode (str|"all_reduce"): grad action for grogram when use_ps_gpu. Returns: tuple: (optimize_ops, params_grads) which are, list of operators appended; and list of (param, grad) Variables pair for optimization. @@ -1139,12 +1141,17 @@ class DownpourOptimizer(DistributedOptimizer): if opt_info["use_ps_gpu"]: from paddle.fluid.transpiler.collective import MultiThread # check start program - + if program_mode not in [ + "all_reduce", "fuse_all_reduce", "all_gather" + ]: + raise ValueError("You should set program_mode in [ all_reduce, \ + fuse_all_reduce, all_gather ]") env = self.get_dist_env() if not isinstance(losses, list): startup_programs = [startup_programs] for i in range(0, len(startup_programs)): - t = MultiThread() + + t = MultiThread(trans_mode=program_mode) start_program = startup_programs[i] main_program = programs[i] t.transpile( diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index ec8602ec7e6..ea88a89e682 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -65,7 +65,7 @@ class Collective(object): self.main_program = default_main_program() self.nranks = len(endpoints) - if self.nranks == 1 and self.mode != "single_process_multi_thread": + if self.nranks == 1 and self.mode != "single_process_multi_thread" and self.mode != "box": raise ValueError('the number of endpoints must > 1') if rank < 0: @@ -441,9 +441,14 @@ class MultiThread(GradAllReduce): ''' ''' - def __init__(self, nrings=1): + def __init__(self, nrings=1, trans_mode="all_reduce"): GradAllReduce.__init__(self, nrings) - self.mode = "single_process_multi_thread" + self.mode = "box" + self.trans_mode = trans_mode + self.fuse_grad_size_in_num = 128 + gpu_nums = os.getenv("FLAGS_selected_gpus", + "0,1,2,3,4,5,6,7,8").split(",") + self.gpu_num = len(gpu_nums) def _transpile_startup_program(self): if len(self.endpoints) > 1: @@ -460,3 +465,259 @@ class MultiThread(GradAllReduce): print("begin to _transpile_startup_program for single-node") block = self.startup_program.global_block() block.append_op(type='c_comm_init_all', attrs={'ring_id': 0}) + + def _transpile_main_program(self): + self._insert_scale_loss_grad_ops() + if self.trans_mode == "all_gather": + print("begin to transpile in all-gather mode") + self.allgather_ranks = self.nranks * self.gpu_num + self._insert_allgather_ops() + self._update_adam_ops() + elif self.trans_mode == "fuse_all_reduce": + print("begin to transpile in fuse all-reduce mode") + self._insert_fuse_allreduce_ops() + else: + print("begin to transpile in all-reduce mode") + self._insert_allreduce_ops() + + def _insert_allgather_ops(self): + """ + insert allgather op to the main_program + """ + block = self.main_program.global_block() + ring_id = -1 + grad = None + for idx, op in reversed(list(enumerate(block.ops))): + if self._is_backward_op(op) and \ + self.op_role_var_key in op.attr_names: + op_role_var = op.all_attrs()[self.op_role_var_key] + if len(op_role_var) == 0: + continue + assert len(op_role_var) % 2 == 0 + + offset = idx + for i in range(0, len(op_role_var), 2): + param = block.vars[op_role_var[i]] + new_grad_var = block.create_var( + name=op_role_var[i] + "_allgather", + shape=[self.allgather_ranks] + list(param.shape), + persistable=False, + dtype=core.VarDesc.VarType.FP32, + stop_gradient=True) + grad = block.vars[op_role_var[i + 1]] + if param.is_distributed: # no need to care: used in PLSC + continue + + if offset == idx: + offset += 1 + block._insert_op( + offset, + type='c_sync_calc_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={self.op_role_key: OpRole.Backward}) + offset += 1 + + # As we search ops reversedly, we should insert c_allgather + # op in the same way to keep the ring_id alternate + ring_id = (ring_id + 1) % self.nrings + block._insert_op( + offset, + type='c_allgather', + inputs={'X': grad}, + outputs={'Out': new_grad_var}, + attrs={ + 'nranks': self.allgather_ranks, + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward + }) + + if grad is None: + return + + for idx, op in enumerate(block.ops): + if self._is_optimizer_op(op): + for ring_id in range(self.nrings): + block._insert_op( + idx + ring_id, + type='c_sync_comm_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward + }) + break + + def _update_adam_ops(self): + """ + remove the original adam op, and add new adam ops + """ + block = self.main_program.global_block() + + for idx, op in reversed(list(enumerate(block.ops))): + if self._is_optimizer_op(op): + offset = idx + if op.type != 'adam' and op.type != 'lamb': # filter out scale op + continue + param_name = op.input("Param")[0] + inputs = { + "Param": block.vars[op.input("Param")[0]], + "LearningRate": block.vars[op.input("LearningRate")[0]], + "Moment1": block.vars[op.input("Moment1")[0]], + "Moment2": block.vars[op.input("Moment2")[0]], + "Beta1Pow": block.vars[op.input("Beta1Pow")[0]], + "Beta2Pow": block.vars[op.input("Beta2Pow")[0]] + } + outputs = { + "ParamOut": block.vars[op.output("ParamOut")[0]], + "Moment1Out": block.vars[op.output("Moment1Out")[0]], + "Moment2Out": block.vars[op.output("Moment2Out")[0]], + "Beta1PowOut": block.vars[op.output("Beta1PowOut")[0]], + "Beta2PowOut": block.vars[op.output("Beta2PowOut")[0]] + } + attrs = { + "epsilon": op.attr('epsilon'), + "beta1": op.attr('beta1'), + "beta2": op.attr('beta2'), + "lazy_mode": op.attr('lazy_mode'), + "min_row_size_to_use_multithread": + op.attr('min_row_size_to_use_multithread') + } + split_vars = [ + block.create_var( + name=param_name + "_" + str(i), + shape=block.vars[op.input("Param")[0]].shape, + persistable=False, + dtype=core.VarDesc.VarType.FP32, + stop_gradient=True) for i in range(self.allgather_ranks) + ] + block._insert_op( + offset, + type="split", + inputs={ + 'X': block.vars[op.input("Param")[0] + "_allgather"] + }, + outputs={'Out': split_vars}, + attrs={'num': self.allgather_ranks, + 'axis': 0}) + offset += 1 + + for i in range(self.allgather_ranks): + inputs["Grad"] = split_vars[i] + block._insert_op( + offset, + type=op.type, + inputs=inputs, + outputs=outputs, + attrs=attrs) + offset += 1 + # remove the original adam op + block._remove_op(offset) + + def _insert_fuse_allreduce_ops(self): + """ + insert coalesce_tensor and all reduce ops + """ + block = self.main_program.global_block() + ring_id = 0 % self.nrings + grad = None + param_grads = [] + # find all grad params + for op in reversed(block.ops): + if self._is_backward_op(op) and \ + self.op_role_var_key in op.attr_names: + op_role_var = op.all_attrs()[self.op_role_var_key] + if len(op_role_var) == 0: + continue + assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \ + "but got odd number of vars" + for i in range(0, len(op_role_var), 2): + param_name = op_role_var[i] + param = block.var(param_name) + grad_name = op_role_var[i + 1] + grad = block.var(grad_name) + if param.is_distributed: + continue + param_grads.append(grad) + if grad is None: + return + + segments = [] + last_dtype = None + # split the grad based on dtype and fused size + for var in param_grads: + if len(segments) == 0 \ + or len(segments[-1]) == self.fuse_grad_size_in_num \ + or var.dtype != last_dtype: + segments.append([var]) + last_dtype = var.dtype + else: + segments[-1].append(var) + + fused_vars = [] + for idx, op in enumerate(block.ops): + if self._is_optimizer_op(op): + for segment in segments: + # insert coalesce tensor + tmp_var = block.create_var( + name=unique_name.generate('FusedOutput_{}'.format( + segment[0].name)), + dtype=segment[0].dtype, + persistable=False, + stop_gradient=True) + fused_vars.append(tmp_var) + block._insert_op( + idx, + type="coalesce_tensor", + inputs={"Input": segment}, + outputs={"Output": segment, + "FusedOutput": tmp_var}, + attrs={ + "copy_data": True, + "use_align": True, + "dtype": segment[0].dtype, + self.op_role_key: OpRole.Backward + }) + break + + # insert the allreduce_sum op + for idx, op in enumerate(block.ops): + if self._is_optimizer_op(op): + for fused_var in fused_vars: + block._insert_op( + idx, + type='c_allreduce_sum', + inputs={'X': fused_var}, + outputs={'Out': fused_var}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': False, + self.op_role_key: OpRole.Backward + }) + block._insert_op( + idx, + type='c_sync_calc_stream', + inputs={'X': fused_var}, + outputs={'Out': fused_var}, + attrs={self.op_role_key: OpRole.Backward}) + break + + if len(fused_vars) == 0: + block._sync_with_cpp() + return + + # insert the sync comm op + for idx, op in enumerate(block.ops): + if self._is_optimizer_op(op): + block._insert_op( + idx, + type='c_sync_comm_stream', + inputs={'X': fused_vars[0]}, + outputs={'Out': fused_vars[0]}, + attrs={ + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward + }) + break + block._sync_with_cpp() -- GitLab