# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings
import paddle
import paddle.compat as cpt
from ..ps.utils.public import *
from paddle.framework import core
from paddle.distributed.passes.pass_base import PassBase, register_pass
from paddle.fluid.transpiler.details.program_utils import delete_ops
from paddle.fluid.transpiler.collective import SingleProcessMultiThread
from collections import deque, defaultdict
from paddle.fluid.framework import Program, Parameter


@register_pass("append_send_ops_pass")
class AppendSendOpsPass(PassBase):
    """Append `send` ops (and, for sync/half-async, a `send_barrier`) to the main program.

    This pass is reused by multiple distributed modes.
    """

    def __init__(self):
        super(AppendSendOpsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _append_send_op(self, program, union_vars, queue, is_sparse, table_id,
                        ps_mode):
        """Append one `send` op to the global block and return its output.

        Returns the control variable created for SYNC / HALF_ASYNC modes, or
        an empty list for every other mode.
        """
        if queue == STEP_COUNTER:
            send_input_vars = []
        else:
            send_input_vars = [
                program.global_block().vars[union_var]
                for union_var in union_vars
            ]

        dummy_output = []
        if ps_mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
            dummy_output = program.global_block().create_var(
                name=framework.generate_control_dev_var_name())

        program.global_block().append_op(type="send",
                                         inputs={"X": send_input_vars},
                                         outputs={"Out": dummy_output},
                                         attrs={
                                             "send_varnames": [queue],
                                             "is_sparse": is_sparse,
                                             "table_id": table_id,
                                             RPC_OP_ROLE_ATTR_NAME:
                                             RPC_OP_ROLE_ATTR_VALUE
                                         })

        return dummy_output

    def _append_barrier_op(self, program, dummys, trainer_id):
        """Append a `send_barrier` op that takes the send outputs as input."""
        program.global_block().append_op(type="send_barrier",
                                         inputs={"X": dummys},
                                         outputs={"Out": []},
                                         attrs={
                                             "trainer_id": trainer_id,
                                             "half_async": True,
                                             RPC_OP_ROLE_ATTR_NAME:
                                             RPC_OP_ROLE_ATTR_VALUE
                                         })

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        """Select the send context for the current mode and append send ops."""
        attrs = pass_ctx._attrs
        ps_mode = attrs['ps_mode']
        if ps_mode == DistributedMode.GEO:
            send_ctx = get_geo_trainer_send_context(attrs)  # geo mode
        elif attrs['is_heter_ps_mode'] == True:
            print("is_heter_ps_mode in append_send_ops_pass!!")
            send_ctx = get_the_one_send_context(attrs, split_dense_table=True)
        else:
            # async, sync and the remaining modes
            send_ctx = get_the_one_send_context(attrs)

        dummys = []
        for merged_name, send in send_ctx.items():
            # Sparse tables are only sent from here in GEO mode.
            if send.is_sparse() and ps_mode != DistributedMode.GEO:
                continue
            # Skip contexts that belong to a program other than the loss's.
            if send.program_id() != id(attrs['loss'].block.program):
                continue
            is_sparse = 1 if send.is_sparse() else 0
            is_sparse = 2 if send.is_distributed() else is_sparse
            dummys.append(
                self._append_send_op(main_program, send.origin_varnames(),
                                     merged_name, is_sparse, send.table_id(),
                                     ps_mode))

        if ps_mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
            trainer_id = get_role_id(attrs['role_maker'])
            self._append_barrier_op(main_program, dummys, trainer_id)


@register_pass("distributed_ops_pass")
class DistributedOpsPass(PassBase):
    """Replace remote-prefetch sparse lookup ops (and their grads) with distributed ops."""

    def __init__(self):
        super(DistributedOpsPass, self).__init__()
        # Filled by _pull_sparse_fuse and read back by _push_sparse_fuse.
        self.w_2_table_id = {}
        self.emb_size = {}

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _push_sparse_fuse(self, _program, push_sparse_ops, attrs, use_cvm_op):
        """Fuse the grad ops of each sparse param into one `distributed_push_sparse` op.

        Does nothing when `attrs['use_ps_gpu']` is set or there is nothing to
        fuse. Relies on `self.w_2_table_id` / `self.emb_size` having been
        populated by `_pull_sparse_fuse`.
        """
        if attrs['use_ps_gpu']:
            return
        if len(push_sparse_ops) == 0:
            return
        show = None
        clk = None
        use_entry = False
        # Only the first op of the first parameter is inspected for "entry".
        op_first = next(iter(push_sparse_ops.values()))[0]
        if op_first.has_attr("entry"):
            entry = op_first.attr("entry")
            entry = entry.split(':')
            if len(entry) == 3 and entry[0] == 'show_click_entry':
                show_var_name = entry[1]
                click_var_name = entry[2]
                if show_var_name in _program.global_block(
                ).vars and click_var_name in _program.global_block().vars:
                    show = _program.global_block().vars[show_var_name]
                    clk = _program.global_block().vars[click_var_name]
                    use_entry = True
                else:
                    warnings.warn(
                        'ShowClickEntry configured, but cannot find show/click var, will not use'
                    )

        if not use_entry:
            # Fall back to constant show=1 / clk=0 variables.
            print('ShowClickEntry not configured, will not use')
            show = _program.global_block().create_var(
                name="show",
                dtype=core.VarDesc.VarType.FP32,
                persistable=False,
                stop_gradient=True)
            _program.global_block()._insert_op(index=0,
                                               type='fill_constant',
                                               inputs={},
                                               outputs={'Out': show},
                                               attrs={
                                                   'shape': [1],
                                                   'dtype': show.dtype,
                                                   'value': 1,
                                               })

            clk = _program.global_block().create_var(
                name="clk",
                dtype=core.VarDesc.VarType.FP32,
                persistable=False,
                stop_gradient=True)
            _program.global_block()._insert_op(index=0,
                                               type='fill_constant',
                                               inputs={},
                                               outputs={'Out': clk},
                                               attrs={
                                                   'shape': [1],
                                                   'dtype': clk.dtype,
                                                   'value': 0,
                                               })

        for param, ops in push_sparse_ops.items():
            all_ops = _program.global_block().ops
            op_idxs = [all_ops.index(op) for op in ops]
            inputs = [
                _program.global_block().vars[op.input("Ids")[0]] for op in ops
            ]
            w = _program.global_block().vars[ops[0].output("W@GRAD")[0]]
            table_id = self.w_2_table_id[param]

            padding_idx = ops[0].attr("padding_idx")
            is_distributed = ops[0].attr("is_distributed")
            op_type = ops[0].type

            slots = [op.attr("slot") for op in ops]
            # NOTE(review): looks like a leftover debug print; kept so the
            # pass's output is unchanged.
            print('debug zcb slots: ', slots)
            outputs = [
                _program.global_block().vars[op.input("Out@GRAD")[0]]
                for op in ops
            ]

            # Remove from the back so earlier indices stay valid.
            for idx in op_idxs[::-1]:
                _program.global_block()._remove_op(idx)

            _program.global_block().append_op(type="distributed_push_sparse",
                                              inputs={
                                                  "Ids": inputs,
                                                  'W': w,
                                                  "Outputs": outputs,
                                                  "Shows": show,
                                                  "Clicks": clk,
                                              },
                                              outputs={"Outputs": outputs},
                                              attrs={
                                                  "is_distributed":
                                                  is_distributed,
                                                  "padding_idx": padding_idx,
                                                  "table_id": table_id,
                                                  "size": self.emb_size[param],
                                                  "use_cvm_op": use_cvm_op,
                                                  "slots": slots
                                              })

    def _pull_sparse_fuse(self, _program, pull_sparse_ops, attrs, send_ctx):
        """Replace the lookup ops of each sparse param with distributed lookup ops.

        Records the table id and embedding size of each param on `self`.
        With `attrs['use_ps_gpu']`, all lookups are collected into a single
        `pull_gpups_sparse` op instead.

        Raises:
            ValueError: if no send context owns the param's gradient, if an op
                precedes the embedding in gpups mode, or if a gpups slot name
                is not an integer.
        """

        def dag_check_up_and_reorder(program, inputs, outputs):
            """Try to move ops so that producers of `inputs` precede consumers of `outputs`.

            Returns early (after a warning) when the reordering is not
            possible; otherwise reorders both the Python op list and the
            underlying desc in place.
            """
            global_block = program.global_block()
            min_output_index = len(global_block.ops)
            max_input_index = -1
            # 0/1 flags per op: produces one of `inputs` / consumes one of
            # `outputs`.
            input_indexes = [0] * len(global_block.ops)
            output_indexes = [0] * len(global_block.ops)
            for idx, op in enumerate(global_block.ops):
                for i in range(0, len(op.output_names)):
                    if input_indexes[idx] == 1:
                        break
                    outs = op.output(op.output_names[i])
                    for in_id, in_var in enumerate(inputs):
                        if in_var.name in outs:
                            input_indexes[idx] = 1
                            max_input_index = max(max_input_index, idx)
                            break

                for i in range(0, len(op.input_names)):
                    if output_indexes[idx] == 1:
                        break
                    ins = op.input(op.input_names[i])
                    for out_id, out_var in enumerate(outputs):
                        if out_var.name in ins:
                            output_indexes[idx] = 1
                            min_output_index = min(min_output_index, idx)

            for i in range(len(global_block.ops)):
                if input_indexes[i] == 1 and output_indexes[i] == 1:
                    warnings.warn(
                        "unable to re-arrange dags order to combine distributed embedding ops because a op both needs embedding table's output as input and produces ids as the same embedding table's input"
                    )
                    return

            if min_output_index < max_input_index:
                move_ops = []
                for i in range(min_output_index + 1, len(input_indexes)):
                    if input_indexes[i] == 1:
                        move_ops.append((global_block.ops[i], i))
                for i, op in enumerate(move_ops):
                    # Breadth-first walk over the ops this producer depends on
                    # inside the [min_output_index, pos) window.
                    queue = list()
                    visited = set()
                    queue.append(op[1])
                    visited.add(op[0])
                    start = 0
                    while start < len(queue):
                        pos = queue[start]
                        op = global_block.ops[pos]
                        op_inputs = []
                        for k in range(0, len(op.input_names)):
                            ins = op.input(op.input_names[k])
                            op_inputs.append(ins)
                        for j in range(pos - 1, min_output_index - 1, -1):
                            op1 = global_block.ops[j]
                            if op1 in visited:
                                continue
                            found = False
                            for k in range(0, len(op1.output_names)):
                                outs = op1.output(op1.output_names[k])
                                for t in range(len(op_inputs)):
                                    for y in op_inputs[t]:
                                        if y in outs:
                                            found = True
                                            break
                                    if found:
                                        break
                                if found:
                                    break
                            if found:
                                if output_indexes[j] == 1:
                                    warnings.warn(
                                        "unable to re-arrange dags order to combine distributed embedding ops"
                                    )
                                    return
                                queue.append(j)
                                visited.add(global_block.ops[j])
                        start = start + 1

                    queue.sort()
                    for index in queue:
                        # Move the op (desc and Python wrapper) to
                        # min_output_index, keeping the flag lists aligned.
                        desc = global_block.desc._insert_op(min_output_index)
                        desc.copy_from(global_block.ops[index].desc)
                        global_block.desc._remove_op(index + 1, index + 2)
                        global_block.ops[index].desc = desc
                        insert_op = global_block.ops.pop(index)
                        input_state = input_indexes.pop(index)
                        output_state = output_indexes.pop(index)
                        global_block.ops.insert(min_output_index, insert_op)
                        input_indexes.insert(min_output_index, input_state)
                        output_indexes.insert(min_output_index, output_state)
                        min_output_index = min_output_index + 1

                assert global_block.desc.op_size() == len(global_block.ops)
                for i in range(len(global_block.ops)):
                    assert global_block.desc.op(i) == global_block.ops[i].desc

        if attrs['use_ps_gpu']:
            gpups_inputs_idxs = list()
            gpups_outputs_idxs = list()
            gpups_inputs = list()
            gpups_outputs = list()
            gpups_w_size = list()
            gpups_min_distributed_idx = len(_program.global_block().ops) + 1

        for param, ops in pull_sparse_ops.items():
            all_ops = _program.global_block().ops
            op_device = ""
            if attrs['is_heter_ps_mode']:
                op_device = ops[0].attr("op_device")
            inputs = [
                _program.global_block().vars[op.input("Ids")[0]] for op in ops
            ]
            w = _program.global_block().vars[ops[0].input("W")[0]]
            self.emb_size[param] = w.shape[1]

            grad_name = attrs['param_name_to_grad_name'][w.name]

            table_id = -1

            for name, ctx in send_ctx.items():
                if grad_name in ctx.origin_varnames():
                    table_id = ctx.table_id()

            if table_id == -1:
                raise ValueError(
                    "can not find suitable sparse table, please check")

            self.w_2_table_id[param] = table_id
            padding_idx = ops[0].attr("padding_idx")
            is_distributed = ops[0].attr("is_distributed")
            op_type = ops[0].type

            outputs = [
                _program.global_block().vars[op.output("Out")[0]] for op in ops
            ]

            dag_check_up_and_reorder(_program, inputs, outputs)

            op_idxs = [all_ops.index(op) for op in ops]

            # Remove from the back so earlier indices stay valid.
            for idx in op_idxs[::-1]:
                _program.global_block()._remove_op(idx)

            # Last producer index per input / first consumer index per output.
            inputs_idxs = [-1] * len(inputs)
            outputs_idxs = [len(_program.global_block().ops) + 1] * len(outputs)

            for idx, op in enumerate(_program.global_block().ops):
                for i in range(0, len(op.output_names)):
                    outs = op.output(op.output_names[i])
                    for in_id, in_var in enumerate(inputs):
                        if in_var.name in outs:
                            inputs_idxs[in_id] = max(idx, inputs_idxs[in_id])
                for i in range(0, len(op.input_names)):
                    ins = op.input(op.input_names[i])
                    for out_id, out_var in enumerate(outputs):
                        if out_var.name in ins:
                            outputs_idxs[out_id] = min(idx,
                                                       outputs_idxs[out_id])

            if attrs['use_ps_gpu']:
                gpups_inputs_idxs.extend(inputs_idxs)
                gpups_outputs_idxs.extend(outputs_idxs)
                gpups_inputs.extend(inputs)
                gpups_outputs.extend(outputs)
                gpups_w_size.extend([w.shape[1]] * len(inputs))
                gpups_min_distributed_idx = min(min(op_idxs),
                                                gpups_min_distributed_idx)
                continue

            if min(outputs_idxs) - max(inputs_idxs) >= 1:
                # Every input is produced before any output is consumed, so a
                # single fused op can serve all lookups.
                if max(inputs_idxs) == -1:
                    distributed_idx = min(op_idxs)
                else:
                    distributed_idx = max(inputs_idxs) + 1

                _program.global_block()._insert_op(
                    index=distributed_idx,
                    type="distributed_lookup_table",
                    inputs={
                        "Ids": inputs,
                        'W': w
                    },
                    outputs={"Outputs": outputs},
                    attrs={
                        "is_distributed": is_distributed,
                        "padding_idx": padding_idx,
                        "table_id": table_id,
                        "lookup_table_version": op_type,
                        "op_device": op_device
                    })
            else:
                # Cannot fuse: insert one distributed op per original lookup.
                for i in range(len(inputs_idxs)):
                    distributed_idx = op_idxs[i]

                    _program.global_block()._insert_op(
                        index=distributed_idx,
                        type="distributed_lookup_table",
                        inputs={
                            "Ids": [inputs[i]],
                            'W': w
                        },
                        outputs={"Outputs": [outputs[i]]},
                        attrs={
                            "is_distributed": is_distributed,
                            "padding_idx": padding_idx,
                            "table_id": table_id,
                            "lookup_table_version": op_type,
                            "op_device": op_device
                        })

        if attrs['use_ps_gpu'] and len(gpups_inputs) > 0:
            if max(gpups_inputs_idxs) > 0:
                raise ValueError("There can't be ops before embedding in gpups")

            _program.global_block()._insert_op(index=gpups_min_distributed_idx,
                                               type="pull_gpups_sparse",
                                               inputs={
                                                   "Ids": gpups_inputs,
                                               },
                                               outputs={"Out": gpups_outputs},
                                               attrs={
                                                   "size": gpups_w_size,
                                                   "is_distributed": True,
                                                   "is_sparse": True
                                               })
            PSGPU = paddle.fluid.core.PSGPU()
            try:
                gpu_slot = [int(var.name) for var in gpups_inputs]
            except (ValueError):
                raise ValueError(
                    "The slot name in gpups Should be able to convert to integer."
                )
            PSGPU.set_slot_vector(gpu_slot)
            gpu_mf_sizes = [x - 3 for x in gpups_w_size]
            PSGPU.set_slot_dim_vector(gpu_mf_sizes)

    def _get_pull_sparse_ops(self, _program, attrs):
        """Group remote-prefetch lookup ops and their grad ops by parameter name.

        Returns:
            (pull_sparse_ops, push_sparse_ops, use_cvm_op): two dicts mapping
            a parameter key to its list of ops, and whether a `cvm` op exists
            in the global block.
        """
        pull_sparse_ops = {}
        pull_sparse_ids = {}
        push_sparse_ops = {}
        use_cvm_op = False
        for op in _program.global_block().ops:
            if op.type in SPARSE_OP_TYPE_DICT.keys() \
                    and op.attr('remote_prefetch') is True:
                param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
                if attrs['is_heter_ps_mode'] and not attrs['is_fl_ps_mode']:
                    # TODO: trick for matchnet, need to modify for heter_ps
                    param_name += op.input("Ids")[0][0]
                ops = pull_sparse_ops.get(param_name, [])
                ops.append(op)
                pull_sparse_ops[param_name] = ops
                ids = pull_sparse_ids.get(param_name, [])
                ids.append(op.input("Ids")[0])
                pull_sparse_ids[param_name] = ids
            if op.type == 'cvm':
                use_cvm_op = True

        for op in _program.global_block().ops:
            if op.type in SPARSE_GRAD_OP_TYPE_DICT.keys():
                param_name = op.input(SPARSE_GRAD_OP_TYPE_DICT[op.type])[0]
                # Only keep grads whose ids were seen on a pulled lookup.
                if param_name in pull_sparse_ids and op.input(
                        "Ids")[0] in pull_sparse_ids[param_name]:
                    ops = push_sparse_ops.get(param_name, [])
                    ops.append(op)
                    push_sparse_ops[param_name] = ops

        return pull_sparse_ops, push_sparse_ops, use_cvm_op

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        """Collect sparse ops, then fuse the pull side followed by the push side."""
        attrs = pass_ctx._attrs
        pull_sparse_ops, push_sparse_ops, use_cvm_op = self._get_pull_sparse_ops(
            main_program, attrs)
        print("is_heter_ps_mode in distributed_ops_pass {}?".format(
            attrs['is_heter_ps_mode']))
        send_ctx = get_the_one_send_context(
            attrs, split_dense_table=attrs['is_heter_ps_mode'])
        self._pull_sparse_fuse(main_program, pull_sparse_ops, attrs, send_ctx)
        self._push_sparse_fuse(main_program, push_sparse_ops, attrs, use_cvm_op)


@register_pass("delete_optimizer_pass")
class DeleteOptimizesPass(PassBase):
    """Remove optimizer / learning-rate ops and their private variables from the main program."""

    def __init__(self):
        super(DeleteOptimizesPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _delete_optimizer_op_and_vars(self, _program, optimize_ops):
        """Delete `optimize_ops` and every input var not listed in an `op_role_var` attr."""
        optimize_vars = []
        optimize_op_role_vars = []
        optimize_need_delete_vars = []

        for op in optimize_ops:
            optimize_vars.extend(op.input_arg_names)
            optimize_op_role_vars.extend(op.attr("op_role_var"))

        optimize_vars = list(set(optimize_vars))
        optimize_op_role_vars = list(set(optimize_op_role_vars))

        for var in optimize_vars:
            if var not in optimize_op_role_vars:
                optimize_need_delete_vars.append(var)
        need_delete_optimize_vars = list(set(optimize_need_delete_vars))

        delete_ops(_program.global_block(), optimize_ops)
        for var in need_delete_optimize_vars:
            if _program.global_block().has_var(var):
                _program.global_block()._remove_var(var)

    def _add_lr_var(self, main_program, attrs):
        """Re-create the learning-rate variable as a persistable var in `main_program`."""
        # Todo: hard code for pe
        lr_var = attrs['origin_main_program'].global_block(
        ).vars["learning_rate_0"]
        main_program.global_block().create_var(name=lr_var.name,
                                               shape=lr_var.shape,
                                               dtype=lr_var.dtype,
                                               type=lr_var.type,
                                               lod_level=lr_var.lod_level,
                                               persistable=True)

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        optimizer_ops = get_optimize_ops(main_program)
        lr_ops = get_lr_ops(main_program)
        optimizer_ops.extend(lr_ops)
        self._delete_optimizer_op_and_vars(main_program, optimizer_ops)

        # 'lr_sheduler' (sic) is the attribute name probed on the program.
        if hasattr(attrs['origin_main_program'], 'lr_sheduler'):
            self._add_lr_var(main_program, attrs)


@register_pass("delete_extra_optimizer_pass")
class DeleteExtraOptimizerPass(PassBase):
    """Remove from the startup program the init ops and vars of optimizer-only variables."""

    def __init__(self):
        super(DeleteExtraOptimizerPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        optimize_vars = []
        optimize_op_role_vars = []
        optimize_need_delete_vars = []

        for op in get_optimize_ops(main_program):
            optimize_vars.extend(op.input_arg_names)
            optimize_op_role_vars.extend(op.attr("op_role_var"))

        optimize_vars = list(set(optimize_vars))
        optimize_op_role_vars = list(set(optimize_op_role_vars))
        for var in optimize_vars:
            if var not in optimize_op_role_vars:
                optimize_need_delete_vars.append(var)
        need_delete_optimize_vars = list(set(optimize_need_delete_vars))

        # Collect every startup op that writes one of the doomed variables.
        init_ops = []
        for var in need_delete_optimize_vars:
            param_init_op = []
            for op in startup_program.global_block().ops:
                if var in op.output_arg_names:
                    param_init_op.append(op)
            init_ops.extend(param_init_op)
        delete_ops(startup_program.global_block(), init_ops)

        for var in need_delete_optimize_vars:
            if startup_program.global_block().has_var(var):
                startup_program.global_block()._remove_var(var)


@register_pass("fake_init_ops_pass")
class FakeInitOpsPass(PassBase):
    """Replace the init op of each sparse table in the startup program with `fake_init`."""

    def __init__(self):
        super(FakeInitOpsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _get_sparse_table_names(self, attrs):
        """Return the de-duplicated union of both `get_sparse_tablenames` queries."""
        dist_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
                                              True)
        sparse_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
                                                False)
        return list(set(dist_varnames + sparse_varnames))

    def _fake_init_sparsetable(self, program, sparse_table_names):
        """Swap each table's single init op for a `fake_init` op of the same shape.

        Raises:
            ValueError: if a table does not have exactly one init op.
        """
        # delete table init op
        for table_name in sparse_table_names:
            table_var = program.global_block().vars[table_name]
            table_param_init_op = []
            for op in program.global_block().ops:
                if table_name in op.output_arg_names:
                    table_param_init_op.append(op)
            init_op_num = len(table_param_init_op)
            if init_op_num != 1:
                raise ValueError("table init op num should be 1, now is " +
                                 str(init_op_num))
            table_init_op = table_param_init_op[0]
            program.global_block().append_op(
                type="fake_init",
                inputs={},
                outputs={"Out": table_var},
                attrs={"shape": table_init_op.attr('shape')})
            delete_ops(program.global_block(), table_param_init_op)

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        sparse_tables = self._get_sparse_table_names(attrs)
        self._fake_init_sparsetable(startup_program, sparse_tables)


@register_pass("ps_gpu_pass")
class PsGpuPass(PassBase):
    """Rewrite the main program for GPU parameter-server training."""

    def __init__(self):
        super(PsGpuPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _add_push_box_sparse_op(self, program):
        """Insert the grad ops of every pull_box_sparse / pull_gpups_sparse op.

        The grad ops are inserted right after the last `lookup_table_grad`
        op, or at index 0 when there is none.
        """
        insert_index = -1
        for idx, op in list(enumerate(program.global_block().ops)):
            if op.type == "lookup_table_grad":
                insert_index = idx
        for op in program.global_block().ops:
            if op.type != "pull_box_sparse" and op.type != "pull_gpups_sparse":
                continue
            grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                op.desc, cpt.to_text(set()), [])
            for op_desc in grad_op_desc:
                new_op_desc = program.global_block().desc._insert_op(
                    insert_index + 1)
                new_op_desc.copy_from(op_desc)
                new_op_desc._set_attr(op_role_attr_name, backward)
                new_op = paddle.fluid.framework.Operator(
                    program.global_block(), new_op_desc)
                program.global_block().ops.insert(insert_index + 1, new_op)
                program.global_block()._sync_with_cpp()

    def _remove_optimizer_var(self, program):
        """Remove optimizer-only variables that belong to embedding parameters."""
        embedding_w = {}
        for idx, op in list(enumerate(program.global_block().ops)):
            if op.type == "lookup_table_grad":
                for name in op.input("W"):
                    embedding_w[name] = 1

        optimize_vars = []
        optimize_op_role_vars = []
        optimize_need_delete_vars = []
        for op in get_optimize_ops(program):
            for name in op.input("Param"):
                if name in embedding_w:
                    optimize_op_role_vars.extend(op.attr("op_role_var"))
                    for key_name in op.input_names:
                        # The learning rate is never removed.
                        if key_name == "LearningRate":
                            continue
                        for var in op.input(key_name):
                            optimize_vars.append(var)

        optimize_vars = list(set(optimize_vars))
        optimize_op_role_vars = list(set(optimize_op_role_vars))

        for var in optimize_vars:
            if var not in optimize_op_role_vars:
                optimize_need_delete_vars.append(var)
        need_delete_optimize_vars = list(set(optimize_need_delete_vars))

        for name in need_delete_optimize_vars:
            if program.global_block().has_var(name):
                program.global_block()._remove_var(name)

    def _remove_lookup_table_grad_op_and_var(self, program):
        """Remove `lookup_table_grad` ops, ops reading their W / W@GRAD, and the W@GRAD vars."""
        lookup_table_grad_var = {}
        remove_op_index = []
        remove_var = []
        for idx, op in list(enumerate(program.global_block().ops)):
            if op.type == "lookup_table_grad":
                for name in op.output("W@GRAD"):
                    lookup_table_grad_var[name] = 1
                    remove_op_index.append(idx)
                    remove_var.append(name)
                for name in op.input("W"):
                    lookup_table_grad_var[name] = 1

        for idx, op in list(enumerate(program.global_block().ops)):
            if op.type == "pull_box_sparse" or op.type == "pull_gpups_sparse":
                continue
            for key_name in op.input_names:
                for var in op.input(key_name):
                    if var in lookup_table_grad_var:
                        remove_op_index.append(idx)
                        break

        # De-duplicate and remove from the back so indices stay valid.
        remove_op_index = list(set(remove_op_index))
        remove_op_index.sort(reverse=True)
        for idx in remove_op_index:
            program.global_block()._remove_op(idx)
        for name in remove_var:
            program.global_block()._remove_var(name)

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        self._add_push_box_sparse_op(main_program)
        self._remove_optimizer_var(main_program)
        self._remove_lookup_table_grad_op_and_var(main_program)


@register_pass("ps_transpile_pass")
class PsTranspilePass(PassBase):
    """Run the `SingleProcessMultiThread` transpiler over the programs."""

    def __init__(self):
        super(PsTranspilePass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        t = SingleProcessMultiThread()
        env = get_dist_env()
        t.transpile(startup_program=startup_program,
                    main_program=main_program,
                    rank=env["trainer_id"],
                    endpoints=env["trainer_endpoints"],
                    current_endpoint=env['current_endpoint'],
                    wait_port=False)


@register_pass("split_heter_worker_ops_pass")
class SplitHeterWorkerOpsPass(PassBase):
    """Build the heter-worker program from the origin program."""

    def __init__(self):
        super(SplitHeterWorkerOpsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _create_heter_program(self, program, attrs, heter_program,
                              program_block_ops_list, heter_ops,
                              block_var_detail):
        # This function mainly includes the following contents:
        # 1. For every heter block:
        #     a) copy heter device op from origin program
        #     b) create variables which belong to heter op:
        #         -> if variable is persistable, clone it in global_scope
        #         -> if variable is temp, create it in heter block
        #     c) create communicate related op as follow:
        #         joint_var.0_1 -> slice -> reshape -> origin_var
        #         origin_var -> origin_program
        #         reshape -> concat -> joint_var.1_2
        #     d) copy send op from origin program for var@grad which is located in current heter block
        #     e) re-check every op in current block if its device is not current heter device
        # 2. Create send op for step counter in last heter-block
        # 3. Create Listen&Serv OP and Send&Recv OP for distributed training
        # 4. update CompileTimeStrategy for heter_program

        optimizer_block = []
        grad_to_block_id = []
        send_grad_var_list = []

        pre_block_idx = heter_program.num_blocks - 1
        role_maker = attrs['role_maker']
        current_device = role_maker._heter_device_type().lower()
        stage_id = int(role_maker._get_stage_id())

        heter_block_ops_forward = program_block_ops_list[stage_id -
                                                         1]["forward"]
        heter_block_ops_backward = program_block_ops_list[stage_id -
                                                          1]["backward"]

        heter_block = heter_program._create_block(pre_block_idx)
        optimizer_block.append(heter_block)
        for _, op in enumerate(heter_block_ops_forward):
            block_append_op(heter_program, program, heter_block, op)

        entrance_vars = block_var_detail[stage_id - 1]["forward"]["entrance"]
        add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
        exit_vars = block_var_detail[stage_id - 1]["forward"]["exit"]
        add_vars_by_var_list(exit_vars, program, heter_program, heter_block)

        first_op_index_fp = len(heter_block.ops)

        if stage_id < len(program_block_ops_list):
            # Not the last stage: backward ops get their own block.
            heter_block_bp = heter_program._create_block(pre_block_idx)
            optimizer_block.append(heter_block_bp)

            for _, op in enumerate(heter_block_ops_backward):
                block_append_op(heter_program, program, heter_block_bp, op)

            bp_entrance_vars = block_var_detail[stage_id -
                                                1]["backward"]["entrance"]
            add_vars_by_var_list(bp_entrance_vars, program, heter_program,
                                 heter_block_bp)
            bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
            add_vars_by_var_list(bp_exit_vars, program, heter_program,
                                 heter_block_bp)
            backward_comm_info = get_communicate_var_info(program,
                                                          stage_id,
                                                          bp_entrance_vars,
                                                          type="backward")

            grad_to_block_id.append(backward_comm_info["block_input_var_name"] +
                                    ":" + str(heter_block_bp.idx))

        else:
            # Last stage: backward ops share the forward block.
            for _, op in enumerate(heter_block_ops_backward):
                block_append_op(heter_program, program, heter_block, op)

            bp_entrance_vars = block_var_detail[stage_id -
                                                1]["backward"]["entrance"]
            add_vars_by_var_list(bp_entrance_vars, program, heter_program,
                                 heter_block)
            bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"]
            add_vars_by_var_list(bp_exit_vars, program, heter_program,
                                 heter_block)

            heter_block_bp = heter_block

        forward_comm_info = get_communicate_var_info(program,
                                                     stage_id,
                                                     entrance_vars,
                                                     type="forward")

        grad_to_block_id.append(forward_comm_info["block_input_var_name"] +
                                ":" + str(heter_block.idx))

        first_op_index_bp = len(heter_block_bp.ops)

        if stage_id <= len(block_var_detail) - 1:
            static_var = insert_communicate_op(program, role_maker, heter_block,
                                               stage_id, first_op_index_fp,
                                               block_var_detail, current_device)
        static_var_bp = insert_communicate_op(program, role_maker,
                                              heter_block_bp, stage_id,
                                              first_op_index_bp,
                                              block_var_detail, current_device,
                                              False)

        # add send op
        send_grad_var_list = add_send_op(
            program, heter_block_bp,
            block_var_detail[stage_id - 1]["backward"]["persistables"])

        # add step counter
        send_input_vars = []
        dummy_output = []
        pserver_endpoints = get_ps_endpoints(role_maker)
        attrs = {
            "message_to_block_id": grad_to_block_id,
            "optimize_blocks": optimizer_block,
            # runtime attribute
            "endpoint": get_heter_worker_endpoint(role_maker),
            "fanin": len(get_previous_stage_trainers(role_maker)),
            "pserver_id": get_role_id(role_maker),
            "distributed_mode": attrs['ps_mode'],
            "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
        }
        # append the listen_and_serv op
        heter_program.global_block().append_op(type="heter_listen_and_serv",
                                               inputs={'X': []},
                                               outputs={},
                                               attrs=attrs)
        # TODO check heter program

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        """
        split heter worker program from origin-program
        1. find heter op (located on different device)
        2. find input&output of every heter-block
        3. create heter worker program, add listen&serv op
        """
        attrs = pass_ctx._attrs
        default_deveice = "cpu"
        program, heter_ops, _, program_block_ops = find_heter_ops(
            main_program, default_deveice)
        if len(heter_ops) == 0:
            warnings.warn(
                "Currently running in Heter Parameter Server mode, but no OP running on heterogeneous devices, Please check your code."
            )
            # NOTE(review): rebinding the local name does not affect the
            # caller's `main_program` object — confirm this is intended.
            main_program = program
            return

        program_block_ops = union_forward_gradient_op(program_block_ops)
        block_vars_detail = find_block_joints(program, program_block_ops,
                                              heter_ops)
        heter_program = framework.Program()
        self._create_heter_program(program, attrs, heter_program,
                                   program_block_ops, heter_ops,
                                   block_vars_detail)
        # NOTE(review): local rebinding only; the caller does not see
        # `heter_program` through this assignment — confirm.
        main_program = heter_program


@register_pass("split_trainer_ops_pass")
class SplitTrainerOpsPass(PassBase):
    """Build the cpu-trainer program from the origin program."""

    def __init__(self):
        super(SplitTrainerOpsPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _replace_ops_by_communicate_op(self, program, attrs, heter_block_index,
                                       ops_list, block_var_detail):
        """Delete `ops_list` from the global block; for block 1 insert a `send_and_recv` op.

        Returns the forward entrance variable names of the block when
        `heter_block_index == 1`, otherwise an empty list.
        """
        all_op = program.global_block().ops
        start_op = ops_list[0]

        # Locate the first op by its string form to know where to insert.
        first_op_idx = -1
        for op in all_op:
            if str(op) == str(start_op):
                first_op_idx = all_op.index(op)
                break

        assert first_op_idx != -1
        delete_same_ops(program.global_block(), ops_list)

        entrance_var = []
        role_maker = attrs['role_maker']
        if heter_block_index == 1:
            next_heter_worker_endpoints = get_next_stage_trainers(role_maker)

            entrance_var = block_var_detail[heter_block_index]["forward"][
                "entrance"]

            comm_info = get_communicate_var_info(program, heter_block_index + 1,
                                                 entrance_var)
            program.global_block()._insert_op(
                index=first_op_idx,
                type="send_and_recv",
                inputs={"X": program.global_block().vars[entrance_var[0]]},
                outputs={"Out": []},
                attrs={
                    "mode": "forward",
                    "send_var_name": entrance_var + ["microbatch_id"],
                    "recv_var_name": [],
                    "message_name": comm_info["block_input_var_name"],
                    "next_endpoints": next_heter_worker_endpoints,
                    "previous_endpoints": [],
                    "trainer_id": get_role_id(role_maker),
                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                })

        return entrance_var

    def _remove_var_pair_by_grad(self, var_name, attrs):
        """Drop the (var, grad) pairs whose merged grad variable is named `var_name`.

        All matching pairs are removed from `merged_variables_pairs`; then the
        first match in `merged_dense_pairs` is removed, or failing that the
        first match in `merged_sparse_pairs`.
        """
        # Filter in place (slice assignment keeps the same list object).
        # Deleting by index while enumerating would skip the element that
        # follows each removed pair.
        merged_pairs = attrs['merged_variables_pairs']
        merged_pairs[:] = [
            pair for pair in merged_pairs
            if pair[1].merged_var.name != var_name
        ]

        for index, pair in enumerate(attrs['merged_dense_pairs']):
            var_grad = pair[1]
            if var_grad.merged_var.name == var_name:
                # Safe: iteration stops immediately after the deletion.
                del attrs['merged_dense_pairs'][index]
                return

        for index, pair in enumerate(attrs['merged_sparse_pairs']):
            var_grad = pair[1]
            if var_grad.merged_var.name == var_name:
                del attrs['merged_sparse_pairs'][index]
                return

    def _remove_trainer_send_op(self, program, attrs, heter_block_index,
                                block_var_detail):
        # if trainer do FF->BP->SEND, it has follow vars: var, var@GRAD
        # if trainer only do SEND, it has one var: var@GRAD
        # Delete Send op, if trainer doesn't have the pair var (var<->var@GRAD)
        persistables = block_var_detail[heter_block_index]["forward"]["persistables"] + \
            block_var_detail[heter_block_index]["backward"]["persistables"]
        need_remove_send_op = []
        need_remove_grad_var = []
        for op in find_send_op(program):
            input_list, _ = find_op_input_output(program,
                                                 program.global_block(), op)
            for var_name in input_list:
                origin_var_name = var_name.split("@GRAD")[0]
                if origin_var_name in persistables:
                    need_remove_send_op.append(op)
                    need_remove_grad_var.append(var_name)
        need_remove_send_op = list(set(need_remove_send_op))
        delete_ops(program.global_block(), need_remove_send_op)
        for grad_var_name in need_remove_grad_var:
            self._remove_var_pair_by_grad(grad_var_name, attrs)

    def _create_trainer_program(self, program, origin_program, attrs,
                                program_block_ops_list, block_var_detail):
        # This function mainly includes the following contents:
        # 1. For every heter block in origin program
        #     a) delete heter op and related variables
        #     b) add send&recv op
        #     c) add communicate ops as follows:
        #         origin_var -> reshape -> concat -> joint_var.0_1
        #         send&recv op(send joint_var.0_1; recv joint_var.1_2)
        #         joint_var.1_2 -> slice -> reshape -> origin_var
        #     d) remove send op which related var@grad is not in trainer program
        # 2. check every op's device
        static_var = []
        for heter_block_index in range(1, len(program_block_ops_list)):
            ops_list = program_block_ops_list[heter_block_index][
                "forward"] + program_block_ops_list[heter_block_index][
                    "backward"]
            static_var += self._replace_ops_by_communicate_op(
                program, attrs, heter_block_index, ops_list, block_var_detail)
            self._remove_trainer_send_op(program, attrs, heter_block_index,
                                         block_var_detail)

        optimizer_block = []
        grad_to_block_id = []

        bp_ops_list = program_block_ops_list[0]["backward"]
        delete_same_ops(program.global_block(), bp_ops_list)
        delete_trainer_useless_var(program, static_var)
        backward_block = create_backward_block(program, origin_program,
                                               bp_ops_list, block_var_detail)

        bp_entrance_vars = block_var_detail[0]["backward"]["entrance"]
        backward_comm_info = get_communicate_var_info(origin_program,
                                                      1,
                                                      bp_entrance_vars,
                                                      type="backward")

        grad_to_block_id.append(backward_comm_info["block_input_var_name"] +
                                ":" + str(backward_block.idx))
        optimizer_block.append(backward_block)
        role_maker = attrs['role_maker']
        attrs = {
            "message_to_block_id": grad_to_block_id,
            "optimize_blocks": optimizer_block,
            # runtime attribute
            "endpoint": get_trainer_endpoint(role_maker),  # trainer endpoint
            "fanin": 0,  # get heter worker
            "pserver_id": get_role_id(role_maker),
            "distributed_mode": attrs['ps_mode'],
            "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)),
            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
        }
        # append the listen_and_serv op
        program.global_block()._insert_op(index=0,
                                          type="heter_listen_and_serv",
                                          inputs={'X': []},
                                          outputs={},
                                          attrs=attrs)

        ## TODO add check for bp block
        #check_op_device(program.global_block(), DEFAULT_DEVICE)

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        """
        split cpu-trainer program from origin-program
        1. find heter op (located on different device)
        2. find input&output of every heter-block
        3. create cpu-trainer program, add send&recv op
        """
        attrs = pass_ctx._attrs
        default_device_ = 'cpu'
        program, heter_ops, default_ops, program_block_ops = find_heter_ops(
            main_program, default_device_)
        program_block_ops = union_forward_gradient_op(program_block_ops)

        block_vars_detail = find_block_joints(program, program_block_ops,
                                              heter_ops)
        trainer_program = program.clone()
        self._create_trainer_program(trainer_program, program, attrs,
                                     program_block_ops, block_vars_detail)
        # NOTE(review): local rebinding only; the caller's `main_program`
        # object is not replaced by this assignment — confirm.
        main_program = trainer_program


@register_pass("set_heter_pipeline_opt_pass")
class SetHeterPipelineOptPass(PassBase):
    """Attach `_heter_pipeline_opt` dicts to the startup and main programs."""

    def __init__(self):
        super(SetHeterPipelineOptPass, self).__init__()

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, pass_ctx):
        attrs = pass_ctx._attrs
        role_maker = attrs['role_maker']
        num_microbatches = attrs['user_defined_strategy'].pipeline_configs[
            'accumulate_steps']

        startup_program._heter_pipeline_opt = {
            "startup_program": startup_program,
            "pipeline_stage": int(role_maker._get_stage_id()) - 1,
            "heter_place": role_maker._heter_device(),
            "is_fl_mode": 1
        }
        main_program._heter_pipeline_opt = {
            "trainer": "HeterPipelineTrainer",
            "device_worker": "HeterSection",
            "trainers":
            role_maker._get_stage_trainers(),  # trainer num in each stage
            "trainer_id": int(role_maker._role_id()),
            "pipeline_stage": int(role_maker._get_stage_id()) - 1,
            "num_pipeline_stages": int(role_maker._get_num_stage()),
            "section_program": main_program,
            "num_microbatches": num_microbatches,
            "heter_place": role_maker._heter_device(),
            "is_fl_mode": 1
        }


@register_pass("split_fl_ops_pass")
class SplitFlOpsPass(PassBase):
    """Split a program into party-A and party-B parts keyed on each op's device attribute."""

    def __init__(self):
        super(SplitFlOpsPass, self).__init__()
        # Device-attribute values that mark which party an op belongs to.
        self.PART_A_DEVICE_FlAG = 'gpu:0'
        self.PART_A_JOINT_OP_DEVICE_FlAG = 'gpu:2'
        self.PART_B_DEVICE_FlAG = 'gpu:1'
        self.PART_B_JOINT_OP_DEVICE_FlAG = 'gpu:3'

    def _check_self(self):
        return True

    def _check_conflict(self, other_pass):
        return True

    def _insert_encrypt_op(self):
        pass

    def _insert_decrypt_op(self):
        pass

    def _clear_op_device_flag(self, program):
        """Reset every non-empty op-device attribute in `program` to ''."""
        for block in program.blocks:
            for op in block.ops:
                device = op.attr(OP_DEVICE_KEY)
                if device != '':
                    op._set_attr(OP_DEVICE_KEY, '')

    def _split_fl_program(self):
        """Copy the ops of block 0 into per-party programs ('a' / 'b').

        Also fills `self.partA_ops` and `self.partB_ops`. Returns the
        `defaultdict` mapping party key to `Program`.
        """
        self.partA_ops = []
        self.partB_ops = []
        party_program_map = defaultdict(Program)
        block = self.ori_main_program.block(0)
        for op in block.ops:
            device = op.attr(OP_DEVICE_KEY)
            if device == self.PART_A_DEVICE_FlAG or device == '' or device == self.PART_A_JOINT_OP_DEVICE_FlAG:
                program = party_program_map['a']
                self.partA_ops.append(op)
            elif device == self.PART_B_DEVICE_FlAG or device == self.PART_B_JOINT_OP_DEVICE_FlAG:
                program = party_program_map['b']
                self.partB_ops.append(op)
            # NOTE(review): an op whose device matches neither party reuses
            # the previous iteration's `program` (or raises NameError on the
            # first op) — confirm whether such devices can occur.
            op_desc = op.desc
            ap_op = program.global_block().desc.append_op()
            ap_op.copy_from(op_desc)
            ap_op._set_attr(OP_DEVICE_KEY, device)

        for key in ['a', 'b']:
            program = party_program_map[key]
            program._sync_with_cpp()

        return party_program_map

    def _insert_partA_communicate_op(self, block, idx):
        """Insert the forward `send_and_recv` op of party A at `idx` in `block`."""
        comm_info = "forward_joint_{}_{}@fl_ps".format(1, 2)
        block._insert_op(
            idx,
            type='send_and_recv',
            inputs={'X': self.partA_to_partB_tensor},
            outputs={'Out': []},
            attrs={
                # mode directly drives forward/backward channel selection
                'mode': 'forward',
                'send_var_name':
                self.partA_to_partB_tensor_name + ["microbatch_id"],
                'recv_var_name': [],
                'message_name': comm_info,
                'next_endpoints':
                get_next_stage_trainers(self.role_maker),  # partB_endpoints
                'previous_endpoints':
                get_previous_stage_trainers(self.role_maker),
                'trainer_id': get_role_id(self.role_maker),  # global id
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
            })
        return

    def _insert_partB_communicate_op(self, block, idx):
        """Insert the backward `send_and_recv` op of party B at `idx` in `block`."""
        comm_info = ("backward_joint_{}_{}@fl_ps".format(2, 1))
        block._insert_op(
            idx,
            type='send_and_recv',
            inputs={'X': self.partB_to_partA_grad},
            outputs={'Out': []},
            attrs={
                'mode': 'backward',
                'send_var_name':
                self.partB_to_partA_grad_name + ["microbatch_id"],
                'recv_var_name': [],
                'message_name': comm_info,
                'next_endpoints':
                get_next_stage_trainers(self.role_maker),  # partA_endpoints
                'previous_endpoints':
                get_previous_stage_trainers(self.role_maker),
                'trainer_id': get_role_id(self.role_maker),  # global id
                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
            })
        return

    def _create_var_for_block(self, vars, block):
        """Create in `block` every variable of `vars` that it cannot already resolve.

        Variables are copied from `self.ori_main_block`.
        """
        for var in vars:
            if block._find_var_recursive(str(var)):
                continue
            source_var = self.ori_main_block._var_recursive(str(var))
            # NOTE(review): the check is on `var`, not `source_var`; when
            # `vars` holds plain names this branch is presumably never taken
            # — confirm the intent.
            if isinstance(var, Parameter):
                dest_var = block.create_parameter(
                    name=source_var.name,
                    shape=source_var.shape,
                    dtype=source_var.dtype,
                    type=source_var.type,
                    lod_level=source_var.lod_level,
                    stop_gradient=source_var.stop_gradient,
                    trainable=source_var.trainable,
                    optimize_attr=source_var.optimize_attr,
                    regularizer=source_var.regularizer,
                    error_clip=source_var.error_clip)
            else:
                dest_var = block._clone_variable(source_var, False)
            dest_var.stop_gradient = source_var.stop_gradient
            if hasattr(source_var, 'is_distributed'):
                dest_var.is_distributed = source_var.is_distributed

    def _get_block_by_idx(self, op_list, program, block_idx):
        # NOTE(review): the visible excerpt ends inside this class; only the
        # statements present in the source are reproduced here.
        if block_idx < len(program.blocks):
            new_block = program.block(block_idx)
        else:
            new_block = program._create_block()
        for _, op in enumerate(op_list):
            ap_op = new_block.desc.append_op()
            ap_op.copy_from(op.desc)
            ap_op._set_attr(OP_DEVICE_KEY, op.attr(OP_DEVICE_KEY))
            vars = op.desc.input_arg_names() + op.desc.output_arg_names()
            self._create_var_for_block(vars, new_block)
new_block._sync_with_cpp() return new_block def _find_joint_forward_op(self, block, flag): op_idx = 0 for op in block.ops: if is_forward_op(op) and op.attr(OP_DEVICE_KEY) == flag: return op_idx else: op_idx += 1 return op_idx def _find_joint_backward_op(self, block, flag): op_idx = 0 for op in block.ops: if is_backward_op(op) and op.attr(OP_DEVICE_KEY) == flag: return op_idx else: op_idx += 1 return op_idx def _get_partB_to_partA_grad(self, block, flag): op_idx = self._find_joint_backward_op(block, flag) op = block.ops[op_idx] vars1 = op.desc.input_arg_names() op_idx = self._find_joint_forward_op(block, flag) op = block.ops[op_idx] vars2 = op.desc.output_arg_names() self.partB_to_partA_grad_name = list(set(vars1) - set(vars2)) self.partB_to_partA_grad = [] for var_name in self.partB_to_partA_grad_name: self.partB_to_partA_grad.append(self.ori_main_block.var(var_name)) def _find_dense_grad_vars(self, bp_op_list): program = self.ori_main_program bp_op_input, bp_op_output = find_ops_list_input_output( program, bp_op_list) return (screen_persistables(program, bp_op_input) + screen_persistables(program, bp_op_output)) def _get_partA_program(self, block): # 1. create block 0 # 1.1 insert send op op_idx = self._find_joint_forward_op(block, self.PART_A_JOINT_OP_DEVICE_FlAG) op_list = [] for i in range(len(block.ops)): op = block.ops[i] op_list.append(op) if i == op_idx: out_name = op.desc.output_arg_names()[0] self.partA_to_partB_tensor_name = op.desc.output_arg_names() self.partA_to_partB_tensor = self.ori_main_block.var(out_name) break first_block = self._get_block_by_idx(op_list, self.partA_program, 0) self._insert_partA_communicate_op(first_block, op_idx + 1) # logger.info('partA-first_block:{}'.format(first_block)) # 2. 
create block 1 bp_op_list = get_bp_op_list(block) push_sparse_op_list = get_distributed_push_sparse_op_list(block) # logger.info('bp_op_list: {}'.format(bp_op_list)) second_block = self._get_block_by_idx(bp_op_list + push_sparse_op_list, self.partA_program, 1) # 2.1. insert partA recv op block_input_flag = "backward_joint_{}_{}@fl_ps".format(2, 1) grad_to_block_id = block_input_flag + ":" + str(second_block.idx) attrs = { "message_to_block_id": [grad_to_block_id], "optimize_blocks": [second_block], "endpoint": get_trainer_endpoint(self.role_maker), ## "fanin": 0, "pserver_id": get_role_id(self.role_maker), "distributed_mode": self.ps_mode, "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)), RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE } second_block._insert_op(index=0, type='heter_listen_and_serv', inputs={'X': []}, outputs={}, attrs=attrs) # 2.2 insert push dense grad op send_ops = find_send_op(self.ori_main_program) # push dense delete_same_ops(block, send_ops) dense_grad_vars = self._find_dense_grad_vars(bp_op_list) add_send_op(self.ori_main_program, second_block, dense_grad_vars) # logger.info('partA-second_block:{}'.format(second_block)) def _get_partB_program(self, block): op_idx1 = self._find_joint_forward_op( block, self.PART_B_JOINT_OP_DEVICE_FlAG) # elementwise_add op op_idx2 = self._find_joint_backward_op(block, self.PART_B_JOINT_OP_DEVICE_FlAG) op_cnt = 0 op_list1 = [] op_list2 = [] op_list3 = [] for op in block.ops: if op_cnt < op_idx1: op_list1.append(op) elif op_cnt <= op_idx2: op_list2.append(op) else: op_list3.append(op) op_cnt += 1 # 1. create block 0 first_block = self._get_block_by_idx(op_list1, self.partB_program, 0) # 2. 
create block 1 second_block = self._get_block_by_idx(op_list2, self.partB_program, 1) # 2.1 insert send op self._insert_partB_communicate_op(second_block, len(op_list2)) # 2.2 insert remain ops second_block = self._get_block_by_idx(op_list3, self.partB_program, 1) # 2.3 insert push dense grad op bp_op_list = get_bp_op_list(second_block) dense_grad_vars = self._find_dense_grad_vars(bp_op_list) add_send_op(self.ori_main_program, second_block, dense_grad_vars) # 3. insert partB recv op block_input_flag = "forward_joint_{}_{}@fl_ps".format(1, 2) grad_to_block_id = block_input_flag + ":" + str(second_block.idx) attrs = { "message_to_block_id": [grad_to_block_id], "optimize_blocks": [second_block], ## what to do? "endpoint": get_heter_worker_endpoint(self.role_maker), "fanin": len(get_previous_stage_trainers(self.role_maker)), "pserver_id": 1, # TODO "distributed_mode": self.ps_mode, "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)), RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE } first_block._insert_op(index=len(op_list1), type="heter_listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs) #logger.info('partB-first_block:{}'.format(first_block)) #logger.info('partB-second_block:{}'.format(second_block)) def _apply_single_impl(self, main_program, startup_program, pass_ctx): attrs = pass_ctx._attrs self.role_maker = attrs['role_maker'] self.ps_mode = attrs['ps_mode'] self.is_part_b = attrs['is_heter_worker'] # TODO self.ori_main_program = main_program self.ori_main_block = main_program.block(0) party_program_map = self._split_fl_program() prog_a = party_program_map['a'] _main_file = ps_log_root_dir + '6_fl_A_main_program.prototxt' debug_program(_main_file, prog_a) self._get_partB_to_partA_grad(prog_a.global_block(), self.PART_A_JOINT_OP_DEVICE_FlAG) prog_b = party_program_map['b'] _main_file = ps_log_root_dir + '6_fl_B_main_program.prototxt' debug_program(_main_file, prog_b) if not self.is_part_b: self.partA_program = framework.Program() 
self._get_partA_program(prog_a.global_block()) pass_ctx._attrs['part_a_main_program'] = self.partA_program self._clear_op_device_flag(self.partA_program) check_program(self.partA_program) else: self.partB_program = framework.Program() self._get_partB_program(prog_b.global_block()) pass_ctx._attrs['part_b_main_program'] = self.partB_program self._clear_op_device_flag(self.partB_program) check_program(self.partB_program)