diff --git a/python/paddle/distributed/fleet/utils/ps_util.py b/python/paddle/distributed/fleet/utils/ps_util.py
index 8bf69a41a7cc839d9dadf724a00f87e934425ac1..ba6fd54a60a5e660fe91b7363ba4de09cd9e899f 100644
--- a/python/paddle/distributed/fleet/utils/ps_util.py
+++ b/python/paddle/distributed/fleet/utils/ps_util.py
@@ -128,9 +128,113 @@ class DistributedInfer:

                 return pull_sparse_ops

             def _pull_sparse_fuse(_program, pull_sparse_ops):
+                def dag_check_up_and_reorder(program, inputs, outputs):
+                    global_block = program.global_block()
+                    min_output_index = len(global_block.ops)
+                    max_input_index = -1
+                    input_indexes = [0] * len(global_block.ops)
+                    output_indexes = [0] * len(global_block.ops)
+                    for idx, op in enumerate(global_block.ops):
+                        for i in range(0, len(op.output_names)):
+                            if input_indexes[idx] == 1:
+                                break
+                            outs = op.output(op.output_names[i])
+                            for in_id, in_var in enumerate(inputs):
+                                if in_var.name in outs:
+                                    input_indexes[idx] = 1
+                                    max_input_index = max(max_input_index, idx)
+                                    break
+
+                        for i in range(0, len(op.input_names)):
+                            if output_indexes[idx] == 1:
+                                break
+                            ins = op.input(op.input_names[i])
+                            for out_id, out_var in enumerate(outputs):
+                                if out_var.name in ins:
+                                    output_indexes[idx] = 1
+                                    min_output_index = min(min_output_index,
+                                                           idx)
+
+                    for i in range(len(global_block.ops)):
+                        if input_indexes[i] == 1 and output_indexes[i] == 1:
+                            warnings.warn(
+                                "unable to re-arrange dags order to combine distributed embedding ops because a op both needs embedding table's output as input and produces ids as the same embedding table's input"
+                            )
+                            return
+
+                    if min_output_index < max_input_index:
+                        move_ops = []
+                        for i in range(min_output_index + 1,
+                                       len(input_indexes)):
+                            if input_indexes[i] == 1:
+                                move_ops.append((global_block.ops[i], i))
+                        for i, op in enumerate(move_ops):
+                            queue = list()
+                            visited = set()
+                            queue.append(op[1])
+                            visited.add(op[0])
+                            start = 0
+                            while start < len(queue):
+                                pos = queue[start]
+                                op = global_block.ops[pos]
+                                op_inputs = []
+                                for k in range(0, len(op.input_names)):
+                                    ins = op.input(op.input_names[k])
+                                    op_inputs.append(ins)
+                                for j in range(pos - 1, min_output_index - 1,
+                                               -1):
+                                    op1 = global_block.ops[j]
+                                    if op1 in visited:
+                                        continue
+                                    found = False
+                                    for k in range(0, len(op1.output_names)):
+                                        outs = op1.output(op1.output_names[k])
+                                        for t in range(len(op_inputs)):
+                                            for y in op_inputs[t]:
+                                                if y in outs:
+                                                    found = True
+                                                    break
+                                            if found:
+                                                break
+                                        if found:
+                                            break
+                                    if found:
+                                        if output_indexes[j] == True:
+                                            warnings.warn(
+                                                "unable to re-arrange dags order to combine distributed embedding ops"
+                                            )
+                                            return
+                                        queue.append(j)
+                                        visited.add(global_block.ops[j])
+                                start = start + 1
+
+                            queue.sort()
+                            for index in queue:
+                                desc = global_block.desc._insert_op(
+                                    min_output_index)
+                                desc.copy_from(global_block.ops[index].desc)
+                                global_block.desc._remove_op(index + 1,
+                                                             index + 2)
+                                global_block.ops[index].desc = desc
+                                insert_op = global_block.ops.pop(index)
+                                input_state = input_indexes.pop(index)
+                                output_state = output_indexes.pop(index)
+                                global_block.ops.insert(min_output_index,
+                                                        insert_op)
+                                input_indexes.insert(min_output_index,
+                                                     input_state)
+                                output_indexes.insert(min_output_index,
+                                                      output_state)
+                                min_output_index = min_output_index + 1
+
+                    assert global_block.desc.op_size() == len(
+                        global_block.ops)
+                    for i in range(len(global_block.ops)):
+                        assert global_block.desc.op(i) == global_block.ops[
+                            i].desc
+
                 for param, ops in pull_sparse_ops.items():
                     all_ops = program.global_block().ops
-                    op_idxs = [all_ops.index(op) for op in ops]
                     inputs = [
                         program.global_block().vars[op.input("Ids")[0]]
@@ -155,23 +259,29 @@ class DistributedInfer:
                         for op in ops
                     ]

+                    dag_check_up_and_reorder(program, inputs, outputs)
+                    op_idxs = [all_ops.index(op) for op in ops]
+
                     for idx in op_idxs[::-1]:
                         program.global_block()._remove_op(idx)

                     inputs_idxs = [-1] * len(inputs)
-                    outputs_idxs = [-1] * len(outputs)
+                    outputs_idxs = [len(program.global_block().ops) + 1] * len(
+                        outputs)

                     for idx, op in enumerate(program.global_block().ops):
                         for i in range(0, len(op.output_names)):
                             outs = op.output(op.output_names[i])
                             for in_id, in_var in enumerate(inputs):
                                 if in_var.name in outs:
-                                    inputs_idxs[in_id] = idx
+                                    inputs_idxs[in_id] = max(idx,
+                                                             inputs_idxs[in_id])

                         for i in range(0, len(op.input_names)):
                             ins = op.input(op.input_names[i])
                             for out_id, out_var in enumerate(outputs):
                                 if out_var.name in ins:
-                                    outputs_idxs[out_id] = idx
+                                    outputs_idxs[out_id] = min(
+                                        idx, outputs_idxs[out_id])

                     if min(outputs_idxs) - max(inputs_idxs) >= 1:
                         distributed_idx = max(inputs_idxs) + 1
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
index 89b2a8237dc65ab8ebd6b145c878e9da5501946d..2874949e3c9b868532fb0bb34a1fb365536994da 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
@@ -111,9 +111,104 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):

         return pull_sparse_ops

     def _pull_sparse_fuse(_program, pull_sparse_ops, use_ps_gpu):
+        def dag_check_up_and_reorder(program, inputs, outputs):
+            global_block = program.global_block()
+            min_output_index = len(global_block.ops)
+            max_input_index = -1
+            input_indexes = [0] * len(global_block.ops)
+            output_indexes = [0] * len(global_block.ops)
+            for idx, op in enumerate(global_block.ops):
+                for i in range(0, len(op.output_names)):
+                    if input_indexes[idx] == 1:
+                        break
+                    outs = op.output(op.output_names[i])
+                    for in_id, in_var in enumerate(inputs):
+                        if in_var.name in outs:
+                            input_indexes[idx] = 1
+                            max_input_index = max(max_input_index, idx)
+                            break
+
+                for i in range(0, len(op.input_names)):
+                    if output_indexes[idx] == 1:
+                        break
+                    ins = op.input(op.input_names[i])
+                    for out_id, out_var in enumerate(outputs):
+                        if out_var.name in ins:
+                            output_indexes[idx] = 1
+                            min_output_index = min(min_output_index, idx)
+
+            for i in range(len(global_block.ops)):
+                if input_indexes[i] == 1 and output_indexes[i] == 1:
+                    warnings.warn(
+                        "unable to re-arrange dags order to combine distributed embedding ops because a op both needs embedding table's output as input and produces ids as the same embedding table's input"
+                    )
+                    return
+
+            if min_output_index < max_input_index:
+                move_ops = []
+                for i in range(min_output_index + 1, len(input_indexes)):
+                    if input_indexes[i] == 1:
+                        move_ops.append((global_block.ops[i], i))
+                for i, op in enumerate(move_ops):
+                    queue = list()
+                    visited = set()
+                    queue.append(op[1])
+                    visited.add(op[0])
+                    start = 0
+                    while start < len(queue):
+                        pos = queue[start]
+                        op = global_block.ops[pos]
+                        op_inputs = []
+                        for k in range(0, len(op.input_names)):
+                            ins = op.input(op.input_names[k])
+                            op_inputs.append(ins)
+                        for j in range(pos - 1, min_output_index - 1, -1):
+                            op1 = global_block.ops[j]
+                            if op1 in visited:
+                                continue
+                            found = False
+                            for k in range(0, len(op1.output_names)):
+                                outs = op1.output(op1.output_names[k])
+                                for t in range(len(op_inputs)):
+                                    for y in op_inputs[t]:
+                                        if y in outs:
+                                            found = True
+                                            break
+                                    if found:
+                                        break
+                                if found:
+                                    break
+                            if found:
+                                if output_indexes[j] == True:
+                                    warnings.warn(
+                                        "unable to re-arrange dags order to combine distributed embedding ops"
+                                    )
+                                    return
+                                queue.append(j)
+                                visited.add(global_block.ops[j])
+                        start = start + 1
+
+                    queue.sort()
+                    for index in queue:
+                        desc = global_block.desc._insert_op(min_output_index)
+                        desc.copy_from(global_block.ops[index].desc)
+                        global_block.desc._remove_op(index + 1, index + 2)
+                        global_block.ops[index].desc = desc
+                        insert_op = global_block.ops.pop(index)
+                        input_state = input_indexes.pop(index)
+                        output_state = output_indexes.pop(index)
+                        global_block.ops.insert(min_output_index, insert_op)
+                        input_indexes.insert(min_output_index, input_state)
+                        output_indexes.insert(min_output_index, output_state)
+                        min_output_index = min_output_index + 1
+
+            assert global_block.desc.op_size() == len(global_block.ops)
+            for i in range(len(global_block.ops)):
+                assert global_block.desc.op(i) == global_block.ops[i].desc
+
         for param, ops in pull_sparse_ops.items():
             all_ops = program.global_block().ops
-            op_idxs = [all_ops.index(op) for op in ops]
+
             inputs = [
                 program.global_block().vars[op.input("Ids")[0]]
                 for op in ops
             ]
@@ -139,23 +234,28 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
                 program.global_block().vars[op.output("Out")[0]]
                 for op in ops
             ]

+            dag_check_up_and_reorder(program, inputs, outputs)
+
+            op_idxs = [all_ops.index(op) for op in ops]
+
             for idx in op_idxs[::-1]:
                 program.global_block()._remove_op(idx)

             inputs_idxs = [-1] * len(inputs)
-            outputs_idxs = [-1] * len(outputs)
+            outputs_idxs = [len(program.global_block().ops) + 1] * len(outputs)

             for idx, op in enumerate(program.global_block().ops):
                 for i in range(0, len(op.output_names)):
                     outs = op.output(op.output_names[i])
                     for in_id, in_var in enumerate(inputs):
                         if in_var.name in outs:
-                            inputs_idxs[in_id] = idx
+                            inputs_idxs[in_id] = max(idx, inputs_idxs[in_id])

                 for i in range(0, len(op.input_names)):
                     ins = op.input(op.input_names[i])
                     for out_id, out_var in enumerate(outputs):
                         if out_var.name in ins:
-                            outputs_idxs[out_id] = idx
+                            outputs_idxs[out_id] = min(idx,
+                                                       outputs_idxs[out_id])

             if min(outputs_idxs) - max(inputs_idxs) >= 1:
                 distributed_idx = max(inputs_idxs) + 1
@@ -187,7 +287,7 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
                     })
         else:
             for i in range(len(inputs_idxs)):
-                distributed_idx = op_idxs[i] + 1
+                distributed_idx = op_idxs[i]

                 program.global_block()._insert_op(
                     index=distributed_idx,
@@ -557,7 +657,6 @@ def find_heter_ops(program, default_device="cpu"):

 def create_heter_program(program, config, heter_program, heter_ops,
                          block_var_detail, current_device):
-
     # This function mainly includes the following contents:
     # 1. For every heter block:
     #    a) copy heter device op from origin program
@@ -1029,7 +1128,6 @@ def insert_send_concat_op(program, block, index, var_name_list, new_var_name,

 def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype,
                          type, new_var_name_list, new_var_shape_list):
-
     if var_name not in program.global_block().vars:
         input_var = program.global_block().create_var(
             name=var_name, shape=var_shape, dtype=dtype, type=type)
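For reference, the placement test that both files now perform before fusing the sparse lookup ops can be illustrated outside of Paddle. The sketch below is a simplified, self-contained model of the index bookkeeping introduced above, not Paddle code: FakeOp and find_fused_insert_point are hypothetical names used only for this illustration. It demonstrates the same idea as the diff: track the last op that produces any of the lookup ids with max, track the first op that consumes any of the lookup outputs with min (initialized one past the end of the op list so unconsumed outputs never block fusion), and fuse only when every producer precedes every consumer.

# Simplified, standalone illustration of the insertion-point check that
# _pull_sparse_fuse applies after dag_check_up_and_reorder has run.
# FakeOp and find_fused_insert_point are hypothetical names for this sketch;
# they are not part of the Paddle API.

class FakeOp:
    def __init__(self, name, inputs, outputs):
        self.name = name
        self.inputs = inputs    # variable names this op reads
        self.outputs = outputs  # variable names this op writes


def find_fused_insert_point(ops, ids_vars, out_vars):
    """Return an index where a single fused lookup op could be inserted,
    or None if ids producers and output consumers interleave."""
    # Last op that produces any of the ids consumed by the lookups.
    last_producer = -1
    # First op that consumes any of the lookup outputs; start one past the
    # end so outputs nobody consumes never block fusion.
    first_consumer = len(ops) + 1

    for idx, op in enumerate(ops):
        if any(v in op.outputs for v in ids_vars):
            last_producer = max(last_producer, idx)
        if any(v in op.inputs for v in out_vars):
            first_consumer = min(first_consumer, idx)

    if first_consumer - last_producer >= 1:
        return last_producer + 1
    return None


if __name__ == "__main__":
    ops = [
        FakeOp("gen_ids_a", inputs=["x"], outputs=["ids_a"]),
        FakeOp("gen_ids_b", inputs=["x"], outputs=["ids_b"]),
        FakeOp("fc", inputs=["emb_a", "emb_b"], outputs=["y"]),
    ]
    # Both lookups can be fused and inserted at index 2, between the ids
    # producers and the first consumer of their outputs; prints 2.
    print(find_fused_insert_point(ops, ["ids_a", "ids_b"], ["emb_a", "emb_b"]))

dag_check_up_and_reorder exists to make this condition hold more often: when a producer of ids appears after a consumer of the embedding outputs, it tries to move the producer chain up toward the front of the block, and it gives up with the warnings shown in the diff when an op is both a producer of ids and a consumer of the embedding outputs, since no single insertion point can then be correct.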