diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 4cc710b226d8f84fadd249a148e754d5330fb564..c6afcfec8a0082bc2a20e88275d5ba428fa3aaf1 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -29,6 +29,7 @@ from .process_group import new_process_group, ProcessGroup, _g_process_group_map # NOTE: If op in _g_special_ops, it will not be resharded. _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] +while_block_info = {} class AllGatherOpDesc: @@ -280,8 +281,20 @@ def _is_overlapped(shape_x, shape_y): return overlapped -def _need_reshard(dist_tensor, dist_op, op_input=True): +def _need_reshard(dist_tensor, + dist_op, + actual_process_mesh, + program, + dist_context, + op_input=True): """Judge the tensor whether needs to be resharded.""" + + def _is_unshard(dims_mapping): + for dim in dims_mapping: + if dim != -1: + return False + return True + is_reshard = False tensor_dist_attr = dist_tensor.dist_attr tensor_name = dist_tensor.serial_tensor.name @@ -289,32 +302,74 @@ def _need_reshard(dist_tensor, dist_op, op_input=True): tensor_process_mesh = tensor_dist_attr.process_mesh op_dist_attr = dist_op.dist_attr op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) - op_process_mesh = op_dist_attr.process_mesh + op_process_mesh = actual_process_mesh if op_input: op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) - op_process_mesh = op_dist_attr.process_mesh if all( map(lambda x: x is not None, [ tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping, op_process_mesh ])): - if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh: - is_reshard = True + # dims_mapping + if tensor_dims_mapping != op_input_dims_mapping: + if dist_op.serial_op.type == "while": + sub_block = program.blocks[dist_op.serial_op.attr( + "sub_block").id] + for op in sub_block.ops: + for var_name in op.input_arg_names: + if var_name == tensor_name: + dist_op_attr = dist_context.get_dist_op_for_program( + op).dist_attr + var_dims_mapping = dist_op_attr.get_input_dims_mapping( + var_name) + if var_dims_mapping != tensor_dims_mapping: + is_reshard = True + break + else: + is_reshard = True + # process_mesh + if tensor_process_mesh != op_process_mesh: + # when processes length is not the same, the dims mapping must be replicative now + if len(tensor_process_mesh.processes) != len( + op_process_mesh.processes): + assert _is_unshard(tensor_dims_mapping) + assert _is_unshard(op_input_dims_mapping) + else: + if dist_tensor.serial_tensor.dtype == paddle.bool: + raise ValueError("Bool var is not supported reshard.") + + # for while op, it should find the process mesh of op actually used the tensor as input + if dist_op.serial_op.type == "while": + sub_block = program.blocks[dist_op.serial_op.attr( + "sub_block").id] + for op in sub_block.ops: + for var_name in op.input_arg_names: + if var_name == tensor_name: + dist_op_attr = dist_context.get_dist_op_for_program( + op).dist_attr + process_mesh = dist_op_attr.process_mesh + if process_mesh == op_process_mesh: + is_reshard = True + break + else: + is_reshard = True else: op_output_dims_mapping = op_dist_attr.get_output_dims_mapping( tensor_name) - op_process_mesh = op_dist_attr.process_mesh if all( map(lambda x: x is not None, [ tensor_dims_mapping, tensor_process_mesh, op_output_dims_mapping, op_process_mesh ])): if tensor_process_mesh != op_process_mesh: 
+ if dist_tensor.serial_tensor.dtype == paddle.bool: + raise ValueError("Bool var is not supported reshard.") is_reshard = True if tensor_dims_mapping != op_output_dims_mapping: raise ValueError( "It is not supported that tensor dims mapping is different from op output dims mapping." ) + return is_reshard @@ -329,13 +384,14 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping): return complete_shape -def find_op_desc_seq(dist_tensor, dist_op): +def find_op_desc_seq(dist_tensor, dist_op, actual_process_mesh, batch_size): """ Find the op description sequence to reshard the source tensor for matching the op requirement. Args: dist_tensor (DistributedTensor): A distributed tensor. dist_op (DistributedOperator): A distributed operator. + actual_process_mesh (ProcessMesh): The actual op process mesh. Returns: Dict, the dict represents the required op description sequence corresponding to process, The key of dict is @@ -350,11 +406,16 @@ def find_op_desc_seq(dist_tensor, dist_op): source_process_shape = source_process_mesh.topology op_dist_attr = dist_op.dist_attr - target_process_mesh = op_dist_attr.process_mesh + target_process_mesh = actual_process_mesh target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) target_process_group = target_process_mesh.processes target_process_shape = target_process_mesh.topology + if source_tensor.shape[0] < 0: + new_shape = list(source_tensor.shape) + new_shape[0] = batch_size + source_tensor.desc.set_shape(new_shape) + complete_shape = _compute_complete_shape( source_tensor.shape, source_process_shape, source_dims_mapping) op_desc_seq = {} @@ -503,7 +564,7 @@ def find_op_desc_seq(dist_tensor, dist_op): return op_desc_seq -def _insert_send_op(block, idx, tensor, dst): +def _insert_send_op(block, idx, tensor, dst, op_role): """Insert send op into block at the given index.""" op_type = 'send_v2' block._insert_op( @@ -514,10 +575,11 @@ def _insert_send_op(block, idx, tensor, dst): 'ring_id': 0, 'peer': dst, 'use_calc_stream': True, + 'op_role': op_role }) -def _insert_recv_op(block, idx, tensor, src): +def _insert_recv_op(block, idx, tensor, src, op_role): """Insert recv op into block at the given index.""" op_type = 'recv_v2' block._insert_op( @@ -531,14 +593,16 @@ def _insert_recv_op(block, idx, tensor, src): 'out_shape': tensor.shape, 'dtype': tensor.dtype, 'use_calc_stream': True, + 'op_role': op_role }) -def _insert_concat_op(block, idx, tensors, axis): +def _insert_concat_op(block, idx, tensors, axis, op_role): """Insert concat op into block at the given block.""" inputs = {'X': tensors} attrs = {} attrs['axis'] = axis + attrs['op_role'] = op_role helper = LayerHelper('concat', **locals()) with paddle.static.program_guard(block.program): out = helper.create_variable_for_type_inference( @@ -548,7 +612,8 @@ def _insert_concat_op(block, idx, tensors, axis): return out -def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name): +def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name, + op_role): """Insert slice op into block at the given block.""" inputs = {'Input': tensor} infer_flags = list(1 for i in range(len(axes))) @@ -556,24 +621,23 @@ def _insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name): "axes": axes, "starts": starts, "ends": ends, - "infer_flags": infer_flags + "infer_flags": infer_flags, + 'op_role': op_role } helper = LayerHelper('slice', **locals()) out = block.create_var( - name=new_var_name, - dtype=tensor.dtype, - type=core.VarDesc.VarType.LOD_TENSOR) + 
name=new_var_name, dtype=tensor.dtype, type=tensor.type) block._insert_op( idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs) return out -def _insert_split_op(block, idx, tensor, num_or_sections): +def _insert_split_op(block, idx, tensor, num_or_sections, op_role): """Insert split op into block at the given index.""" helper = LayerHelper('split', **locals()) input_shape = tensor.shape inputs = {'X': tensor} - attrs = {'num': num_or_sections, "axis": 0} + attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role} with paddle.static.program_guard(block.program): outs = [ helper.create_variable_for_type_inference( @@ -584,7 +648,7 @@ def _insert_split_op(block, idx, tensor, num_or_sections): return outs -def _insert_allgather_op(block, idx, tensor, ranks): +def _insert_allgather_op(block, idx, tensor, ranks, op_role): """Insert allgather op into block at the given index.""" def _insert_fill_constant_op(block, idx): @@ -597,6 +661,7 @@ def _insert_allgather_op(block, idx, tensor, ranks): attrs['str_value'] = str(int("1")) attrs['value'] = int("1") attrs['dtype'] = out.dtype + attrs['op_role'] = op_role utils.get_shape_tensor_inputs( inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant') block._insert_op( @@ -625,14 +690,16 @@ def _insert_allgather_op(block, idx, tensor, ranks): inputs={'X': [fill_constant_out]}, outputs={'Out': [fill_constant_out]}, attrs={'ring_id': 0, - 'use_calc_stream': True}) + 'use_calc_stream': True, + 'op_role': op_role}) # insert c_sync_calc_stream op block._insert_op( idx + 2, type="c_sync_calc_stream", inputs={'X': [fill_constant_out]}, - outputs={'Out': [fill_constant_out]}) + outputs={'Out': [fill_constant_out]}, + attrs={'op_role': op_role}) idx_offset = 3 # insert c_allgather op @@ -649,20 +716,21 @@ def _insert_allgather_op(block, idx, tensor, ranks): attrs={ 'ring_id': group.id, 'use_calc_stream': True, - 'nranks': group.nranks + 'nranks': group.nranks, + 'op_role': op_role }) idx_offset += 1 # insert split op split_out = _insert_split_op(block, idx + idx_offset, allgather_out, - group.nranks) + group.nranks, op_role) idx_offset += 1 tensor_list.extend(split_out) return tensor_list, idx_offset def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index, - block, idx): + block, idx, op_role): """Concat the tensors and insert concat op.""" if not partition_tensor_list: partition_tensor_list.append((tensor, partition_index)) @@ -674,13 +742,13 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index, partition_tensor_list[i][1], partition_index) if concat_axis != -1: has_concat = True - _ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis) \ + _ = _insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \ if first_order == 0 else \ - _insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis) + _insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role) partition_tensor_list.pop(i) idx[0] += 1 _concat_partitions_with_op(partition_tensor_list, _, - new_partition, block, idx) + new_partition, block, idx, op_role) break i += 1 if not has_concat: @@ -692,8 +760,47 @@ HAS_RECV = {} HAS_ALLGATHER = {} -def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op, - dist_context): +def _get_while_op_actual_process_mesh(op, program, rank_id, dist_context): + """Get the while op actual Process mesh corresponding to rank""" + assert op.type == "while" + 
while_op_process_mesh = dist_context.get_dist_op_for_program(
+        op).dist_attr.process_mesh
+    sub_block = program.blocks[op.attr("sub_block").id]
+    ops = sub_block.ops
+    actual_process_mesh = None
+    for op in ops:
+        dist_op = dist_context.get_dist_op_for_program(op)
+        if not dist_op:
+            continue
+        process_mesh = dist_op.dist_attr.process_mesh
+        if process_mesh == while_op_process_mesh:
+            continue
+        if rank_id in process_mesh.processes:
+            actual_process_mesh = process_mesh
+            break
+
+    if actual_process_mesh is None and rank_id in while_op_process_mesh.processes:
+        actual_process_mesh = while_op_process_mesh
+
+    assert actual_process_mesh is not None
+    return actual_process_mesh
+
+
+def _get_var(var_name, block, program):
+    """Get var in the parent block if not found in the current block"""
+    var = None
+    if var_name in block.vars:
+        var = block.vars[var_name]
+    else:
+        parent_block = program.blocks[block.parent_idx]
+        if var_name in parent_block.vars:
+            var = parent_block.vars[var_name]
+    assert var is not None
+    return var
+
+
+def parse_op_desc(block, rank_id, op_desc_seq, var_name, reshard_op,
+                  dist_context, program, actual_process_mesh):
     """Parse op desc sequence and insert op in the block"""
     global HAS_SENT
     global HAS_RECV
@@ -703,9 +810,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
     if rank_id not in op_desc_seq.keys():
         return
     op_desc_list = op_desc_seq[rank_id]
-    block = program.global_block()
-    assert var_name in block.vars.keys(
-    ), "The {} cannot be found in the {} program.".format(var_name, rank_id)
 
     idx = None
     for index, op in list(enumerate(block.ops)):
@@ -716,7 +820,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
                 rank_id)
 
     matched_op = block.ops[idx]
-    source_tensor = block.vars[var_name]
+    source_tensor = _get_var(var_name, block, program)
     for op_desc in op_desc_list:
         if isinstance(op_desc, AllGatherOpDesc):  # noqa: F401
             if var_name not in HAS_ALLGATHER.keys():
@@ -724,7 +828,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             if not HAS_ALLGATHER[var_name] or op_desc.group not in list(
                     map(lambda x: x[0], HAS_ALLGATHER[var_name])):
                 tensor_list, idx_offset = _insert_allgather_op(
-                    block, idx, source_tensor, op_desc.group)
+                    block, idx, source_tensor, op_desc.group,
+                    reshard_op.attr('op_role'))
                 idx += idx_offset
                 tensor_name_list = [var.name for var in tensor_list]
                 HAS_ALLGATHER[var_name].append(
@@ -743,7 +848,8 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             if var_name not in HAS_SENT.keys():
                 HAS_SENT[var_name] = []
             if op_desc.dst not in HAS_SENT[var_name]:
-                _insert_send_op(block, idx, source_tensor, op_desc.dst)
+                _insert_send_op(block, idx, source_tensor, op_desc.dst,
+                                reshard_op.attr('op_role'))
                 idx += 1
                 HAS_SENT[var_name].append(op_desc.dst)

@@ -758,8 +864,10 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
                 recv_tensor = block.create_var(
                     name=unique_name.generate(var_name + "@recv"),
                     shape=shape,
-                    dtype=source_tensor.dtype)
-                _insert_recv_op(block, idx, recv_tensor, op_desc.src)
+                    dtype=source_tensor.dtype,
+                    type=source_tensor.type)
+                _insert_recv_op(block, idx, recv_tensor, op_desc.src,
+                                reshard_op.attr('op_role'))
                 tensor_list.append(recv_tensor)
                 idx += 1
                 HAS_RECV[var_name][op_desc.src] = recv_tensor
@@ -772,7 +880,7 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             for index, tensor in enumerate(tensor_list):
                 _concat_partitions_with_op(partition_tensor_list, tensor,
                                            partition_index_list[index], block,
-                                           idx_list)
+                                           idx_list, reshard_op.attr('op_role'))
             idx = idx_list[0]

         elif isinstance(op_desc, SliceOpDesc):
@@ -787,11 +895,11 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
                 starts=op_desc.starts,
                 ends=op_desc.ends,
                 axes=op_desc.axes,
-                new_var_name=new_name)
+                new_var_name=new_name,
+                op_role=reshard_op.attr('op_role'))

             tensor_attr = TensorDistributedAttribute()
-            process_mesh = dist_context.get_op_dist_attr_for_program(
-                matched_op).process_mesh
+            process_mesh = actual_process_mesh
             dims_mapping = dist_context.get_op_dist_attr_for_program(
                 matched_op).get_input_dims_mapping(var_name)
             tensor_attr.dims_mapping = dims_mapping
@@ -799,11 +907,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
             dist_context.set_tensor_dist_attr_for_program(target_tensor,
                                                           tensor_attr)

+            if op.type == "while":
+                global while_block_info
+                # var_reshard_mapping maps a while op input name to the reshard result name it should be renamed to
+                if "var_reshard_mapping" not in while_block_info[op.attr(
+                        "sub_block").id].keys():
+                    while_block_info[op.attr("sub_block").id][
+                        "var_reshard_mapping"] = {}
+                while_block_info[op.attr("sub_block").id][
+                    "var_reshard_mapping"][var_name] = target_tensor.name
+
             # rename op input name according to new name
             for op in block.ops:
                 for name in op.input_arg_names:
                     op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                     if name == var_name and op_dist_attr is not None:
+                        if op.desc.id() == matched_op.desc.id():
+                            op.desc._rename_input(name, target_tensor.name)
+                            op_dist_attr.set_input_dims_mapping(
+                                target_tensor.name, dims_mapping)
+                            op_dist_attr.set_input_dist_attr(name, None)
+                            continue
+
+                        # NOTE: For an op whose process mesh is a union, its input will not be renamed by the reshard result of other ops for now, which means it may need more reshard operations.
                         op_process_mesh = op_dist_attr.process_mesh
                         op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
                             var_name)
@@ -819,102 +945,166 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
     not_remove_op_ref = [
         "create_py_reader", "create_double_buffer_reader", "read"
     ]
-    remove_op_idx = []
-    block = auto_parallel_main_prog.global_block()
-    ops = block.ops
-    vars = block.vars
-    for idx, op in enumerate(ops):
-        # handle read op in the pipeline scene specially, it will be removed in the future.
-        if op.type == "read":
-            dim_list = []
-            for var_name in op.output_arg_names:
-                dim_list.extend(vars[var_name].shape)
-            for i in range(idx, -1, -1):
-                if ops[i].type == "create_py_reader":
-                    ops[i]._set_attr("shape_concat", dim_list)
-                    break
-            continue
-
-        # replace the input and output of c_sync_comm_stream op when in pipeline scene.
-        if op.type == "c_sync_comm_stream":
-            need_save = []
-            for var_name in op.input_arg_names:
-                process_mesh = dist_context.get_tensor_dist_attr_for_program(
-                    vars[var_name]).process_mesh
-                if rank_id in process_mesh.processes:
-                    need_save.append(var_name)
-            if not need_save:
-                remove_op_idx.append(idx)
+    global while_block_info
+
+    # NOTE: Nested sub blocks are not supported now.
+ remove_block_order = [] + for block_idx in while_block_info: + remove_block_order.append(block_idx) + + for block_idx, block in enumerate(auto_parallel_main_prog.blocks): + if block_idx not in remove_block_order: + remove_block_order.append(block_idx) + + # the sub block should be removed first + for block_idx in remove_block_order: + remove_op_idx = [] + block = auto_parallel_main_prog.blocks[block_idx] + ops = block.ops + vars = block.vars + for idx, op in enumerate(ops): + if op.type == "read": + dim_list = [] + for var_name in op.output_arg_names: + dim_list.extend( + _get_var(var_name, block, auto_parallel_main_prog) + .shape) + for i in range(idx, -1, -1): + if ops[i].type == "create_py_reader": + ops[i]._set_attr("shape_concat", dim_list) + break continue - proto = OpProtoHolder.instance().get_op_proto(op.type) - op.desc.set_input(proto.inputs[0].name, need_save) - op.desc.set_output(proto.outputs[0].name, need_save) - continue + # replace the input and output of c_sync_comm_stream op when in pipeline scene. + if op.type == "c_sync_comm_stream": + need_save = [] + for var_name in op.input_arg_names: + process_mesh = dist_context.get_tensor_dist_attr_for_program( + _get_var(var_name, block, + auto_parallel_main_prog)).process_mesh + if rank_id in process_mesh.processes: + need_save.append(var_name) + if not need_save: + remove_op_idx.append(idx) + continue - # judge the other op whether should be removed. - op_dist_attr = dist_context.get_op_dist_attr_for_program(op) - if op_dist_attr is not None: - op_process_mesh = op_dist_attr.process_mesh - if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref: - remove_op_idx.append(idx) + proto = OpProtoHolder.instance().get_op_proto(op.type) + op.desc.set_input(proto.inputs[0].name, need_save) + op.desc.set_output(proto.outputs[0].name, need_save) + continue - for idx in remove_op_idx[::-1]: - block._remove_op(idx) + # judge the other op whether should be removed. 
+ op_dist_attr = dist_context.get_op_dist_attr_for_program(op) + if op_dist_attr is not None: + op_process_mesh = op_dist_attr.process_mesh + if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref: + remove_op_idx.append(idx) + + for idx in remove_op_idx[::-1]: + block._remove_op(idx) def _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads): """Remove no need vars in the main program""" - remove_vars = set() - block = auto_parallel_main_prog.global_block() - ops = block.ops - vars = block.vars - need_vars = set() - for op in ops: - for var_name in op.input_arg_names: - if var_name in vars: - need_vars.add(var_name) - for var_name in op.output_arg_names: - if var_name in vars: - need_vars.add(var_name) - for var in vars: - if var not in need_vars: - remove_vars.add(var) - - # change dist_params_grads - param_grad_map = {} - for op in ops: - if int(op.attr('op_role')) == int(OpRole.Optimize): - if "Param" in op.input_names and "Grad" in op.input_names: - param_name = op.input("Param")[0] - grad_name = op.input("Grad")[0] - param_grad_map[param_name] = grad_name - - need_remove_idx = [] - for idx, item in enumerate(dist_params_grads): - if item[0].name not in param_grad_map.keys(): - need_remove_idx.append(idx) - - for idx in need_remove_idx[::-1]: - dist_params_grads.pop(idx) - - idx = 0 - while idx < len(dist_params_grads): - param_name = dist_params_grads[idx][0].name - grad_name = dist_params_grads[idx][1].name - if grad_name != param_grad_map[param_name]: - dist_params_grads[idx] = (vars[param_name], - vars[param_grad_map[param_name]]) - idx += 1 + for block_idx, block in enumerate(auto_parallel_main_prog.blocks): + remove_vars = set() + ops = block.ops + vars = block.vars + need_vars = set() + for op in ops: + for var_name in op.input_arg_names: + if var_name in vars: + need_vars.add(var_name) + for var_name in op.output_arg_names: + if var_name in vars: + need_vars.add(var_name) + for var in vars: + if var not in need_vars: + remove_vars.add(var) + + # change dist_params_grads, the optimize op just in block 0. 
+ if block_idx == 0: + param_grad_map = {} + for op in ops: + if int(op.attr('op_role')) == int(OpRole.Optimize): + if "Param" in op.input_names and "Grad" in op.input_names: + param_name = op.input("Param")[0] + grad_name = op.input("Grad")[0] + param_grad_map[param_name] = grad_name + + need_remove_idx = [] + for idx, item in enumerate(dist_params_grads): + if item[0].name not in param_grad_map.keys(): + need_remove_idx.append(idx) + + for idx in need_remove_idx[::-1]: + dist_params_grads.pop(idx) + + idx = 0 + while idx < len(dist_params_grads): + param_name = dist_params_grads[idx][0].name + grad_name = dist_params_grads[idx][1].name + if grad_name != param_grad_map[param_name]: + dist_params_grads[idx] = (vars[param_name], + vars[param_grad_map[param_name]]) + idx += 1 - for var in remove_vars: - block._remove_var(var) + for var in remove_vars: + block._remove_var(var) + + +def _change_while_op_input_and_output(auto_parallel_main_prog, dist_context): + """Change while op input and output after the corresponding sub block ops removed""" + global while_block_info + for sub_block_idx in while_block_info: + sub_block = auto_parallel_main_prog.blocks[sub_block_idx] + parent_while_op_id = while_block_info[sub_block_idx]["op_id"] + parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx] + + sub_block_op_inputs = set() + sub_block_op_outputs = [] + for op in sub_block.ops: + # skip the input and output of operators inserted in the reshard phase + dist_op = dist_context.get_dist_op_for_program(op) + if dist_op: + for var_name in op.output_arg_names: + if var_name not in sub_block_op_outputs: + sub_block_op_outputs.append(var_name) + for var_name in op.input_arg_names: + sub_block_op_inputs.add(var_name) + + # find the while op + while_op = None + for op in parent_block.ops: + if op.desc.id() == parent_while_op_id and op.type == "while": + while_op = op + break + + assert while_op is not None + + # find the actual input and output of while op + proto = OpProtoHolder.instance().get_op_proto(while_op.type) + new_X = [] + for var_name in while_op.input("X"): + if var_name in sub_block_op_inputs: + new_X.append(var_name) + assert new_X + while_op.desc.set_input(proto.inputs[0].name, new_X) + + new_Out = [] + for var_name in while_op.output("Out"): + for output_name in sub_block_op_outputs[::-1]: + if output_name.find(var_name) != -1: + new_Out.append(output_name) + assert new_Out + while_op.desc.set_output(proto.outputs[0].name, new_Out) def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id, dist_params_grads): """Remove no need vars and ops in the main program.""" _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id) + _change_while_op_input_and_output(auto_parallel_main_prog, dist_context) _remove_no_need_vars(auto_parallel_main_prog, dist_params_grads) @@ -992,8 +1182,70 @@ def remove_no_need_in_startup(auto_parallel_main_prog, startup_block._remove_op(idx) -def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, - dist_context, dist_params_grads): +def _get_process_meshes(op, program, dist_context): + """Get all process meshes when op has sub block.""" + assert op.has_attr("sub_block") + sub_block = program.blocks[op.attr("sub_block").id] + ops = sub_block.ops + op_process_mesh = dist_context.get_dist_op_for_program( + op).dist_attr.process_mesh + process_meshes = [] + for op in ops: + dist_op = dist_context.get_dist_op_for_program(op) + if not dist_op: + continue + process_mesh = dist_op.dist_attr.process_mesh + if process_mesh 
not in process_meshes and process_mesh != op_process_mesh: + process_meshes.append(process_mesh) + + if not process_meshes: + process_meshes.append(op_process_mesh) + + return process_meshes + + +def _is_condition_replicative(op, program, dist_context): + assert op.type == "while" + sub_block = program.blocks[op.attr("sub_block").id] + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + + # the dims mapping of condition tensor should be replicative + for var_name in op.input("Condition"): + var = _get_var(var_name, sub_block, program) + dist_tensor = dist_context.get_dist_tensor_for_program(var) + tensor_dist_attr = dist_tensor.dist_attr + var_dims_mapping = tensor_dist_attr.dims_mapping + for dim in var_dims_mapping: + if dim != -1: + return False + + return True + + +def _get_op_process_meshes(op, dist_context): + process_meshes = [] + dist_op = dist_context.get_dist_op_for_program(op) + op_process_mesh = dist_op.dist_attr.process_mesh + for process_mesh in dist_context.process_meshes: + if set(process_mesh.processes) & ( + set(op_process_mesh.processes) + ) and len(process_mesh.processes) <= len(op_process_mesh.processes): + process_meshes.append(process_mesh) + + # it means the process mesh is not a union when process meshes is null + if not process_meshes: + process_meshes.append(op_process_mesh) + + return process_meshes + + +def reshard(auto_parallel_main_prog, + auto_parallel_startup_prog, + rank_id, + dist_context, + dist_params_grads, + batch_size=None): """ Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute. @@ -1019,65 +1271,137 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, return True return False - block = auto_parallel_main_prog.global_block() - idx = 0 - while idx < len(block.ops): - pre_op_count = len(block.ops) - op = block.ops[idx] + global while_block_info + for block_idx, block in enumerate(auto_parallel_main_prog.blocks): + if block_idx in while_block_info: + if "var_reshard_mapping" in while_block_info[block_idx]: + var_reshard_mapping = while_block_info[block_idx][ + "var_reshard_mapping"] + for op in block.ops: + for var_name in op.input_arg_names: + if var_name in var_reshard_mapping: + op.desc._rename_input(var_name, + var_reshard_mapping[var_name]) + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + if op_dist_attr.process_mesh == while_block_info[ + block_idx]["actual_process_mesh"]: + dims_mapping = op_dist_attr.get_input_dims_mapping( + var_name) + op_dist_attr.set_input_dims_mapping( + var_reshard_mapping[var_name], dims_mapping) + op_dist_attr.set_input_dist_attr(var_name, None) + + # the outputs also need to be renamed when the output name is the same with input name + for var_name in op.output_arg_names: + if var_name in var_reshard_mapping: + op.desc._rename_output( + var_name, var_reshard_mapping[var_name]) + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + if op_dist_attr.process_mesh == while_block_info[ + block_idx]["actual_process_mesh"]: + dims_mapping = op_dist_attr.get_output_dims_mapping( + var_name) + op_dist_attr.set_output_dims_mapping( + var_reshard_mapping[var_name], dims_mapping) + op_dist_attr.set_output_dist_attr(var_name, + None) + + idx = 0 + while idx < len(block.ops): + pre_op_count = len(block.ops) + op = block.ops[idx] + + if _is_special_op(op): + idx += 1 + continue - if _is_special_op(op): - idx += 1 - continue + dist_op = 
dist_context.get_dist_op_for_program(op) + if dist_op is not None: + process_meshes = [] + if op.type == "while": + if not _is_condition_replicative( + op, auto_parallel_main_prog, dist_context): + raise ValueError( + "Please check the condition due to the dims mapping is not replicative." + ) + process_meshes = _get_process_meshes( + op, auto_parallel_main_prog, dist_context) + assert process_meshes + if op.attr("sub_block").id not in while_block_info: + while_block_info[op.attr("sub_block").id] = {} + while_block_info[op.attr("sub_block").id][ + "op_id"] = op.desc.id() + while_block_info[op.attr("sub_block").id][ + "actual_process_mesh"] = _get_while_op_actual_process_mesh( + op, auto_parallel_main_prog, rank_id, dist_context) + else: + process_meshes = _get_op_process_meshes(op, dist_context) + input_vars = None + if op.type == "while": + input_var_names = op.input("X") + else: + input_var_names = op.input_arg_names + idx_offset = 0 + for var_name in op.input_arg_names: + # skip lod_tensor_blocking_queue_0 + if var_name == "lod_tensor_blocking_queue_0": + continue + var = _get_var(var_name, block, auto_parallel_main_prog) + dist_tensor = dist_context.get_dist_tensor_for_program(var) + for process_mesh in process_meshes: + if dist_tensor is not None and _need_reshard( + dist_tensor, dist_op, process_mesh, + auto_parallel_main_prog, dist_context): + reshard_op_desc = find_op_desc_seq( + dist_tensor, dist_op, process_mesh, batch_size) + parse_op_desc(block, rank_id, reshard_op_desc, + var_name, op, dist_context, + auto_parallel_main_prog, process_mesh) + cur_op_count = len(block.ops) + idx_offset = idx_offset + cur_op_count - pre_op_count + pre_op_count = cur_op_count + idx = idx + idx_offset + 1 + else: + idx += 1 - dist_op = dist_context.get_dist_op_for_program(op) - if dist_op is not None: - idx_offset = 0 - for var_name in op.input_arg_names: - # skip lod_tensor_blocking_queue_0 - if var_name == "lod_tensor_blocking_queue_0": - continue - var = block.vars[var_name] - dist_tensor = dist_context.get_dist_tensor_for_program(var) - if dist_tensor is not None and _need_reshard(dist_tensor, - dist_op): - reshard_op_desc = find_op_desc_seq(dist_tensor, dist_op) - parse_op_desc(auto_parallel_main_prog, rank_id, - reshard_op_desc, var_name, op, dist_context) - cur_op_count = len(block.ops) - idx_offset = idx_offset + cur_op_count - pre_op_count - pre_op_count = cur_op_count - idx = idx + idx_offset + 1 - else: - idx += 1 - - # insert send and recv op if output process mesh is different from tensor process mesh - idx = 0 - skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"] - skip_ops += _g_special_ops - while idx < len(block.ops): - pre_op_count = len(block.ops) - op = block.ops[idx] - dist_op = dist_context.get_dist_op_for_program(op) - if dist_op is not None and op.type not in skip_ops: - for var_name in op.output_arg_names: - var = block.vars[var_name] - dist_tensor = dist_context.get_dist_tensor_for_program(var) - if dist_tensor is not None and _need_reshard(dist_tensor, - dist_op, False): - for index, item in enumerate( - dist_op.dist_attr.process_mesh.processes): - recv_rank = dist_tensor.dist_attr.process_mesh.processes[ - index] - if rank_id == item: - _insert_send_op(block, idx + 1, var, recv_rank) - if rank_id == recv_rank: - _insert_recv_op(block, idx + 1, var, item) - cur_op_count = len(block.ops) - idx_offset = idx_offset + cur_op_count - pre_op_count - pre_op_count = cur_op_count - idx = idx + idx_offset + 1 - else: - idx += 1 + # insert send and recv op if 
output process mesh is different from tensor process mesh + idx = 0 + # skip reader and ops whose process mesh is union + skip_ops = [ + "create_py_reader", "create_double_buffer_reader", "read", "while", + "write_to_array", "read_from_array" + ] + skip_ops += _g_special_ops + while idx < len(block.ops): + pre_op_count = len(block.ops) + op = block.ops[idx] + dist_op = dist_context.get_dist_op_for_program(op) + if dist_op is not None and op.type not in skip_ops: + for var_name in op.output_arg_names: + var = _get_var(var_name, block, auto_parallel_main_prog) + dist_tensor = dist_context.get_dist_tensor_for_program(var) + process_mesh = dist_op.dist_attr.process_mesh + if dist_tensor is not None and _need_reshard( + dist_tensor, dist_op, process_mesh, + auto_parallel_main_prog, dist_context, False): + for index, item in enumerate( + dist_op.dist_attr.process_mesh.processes): + recv_rank = dist_tensor.dist_attr.process_mesh.processes[ + index] + if rank_id == item: + _insert_send_op(block, idx + 1, var, recv_rank, + op.attr('op_role')) + if rank_id == recv_rank: + _insert_recv_op(block, idx + 1, var, item, + op.attr('op_role')) + cur_op_count = len(block.ops) + idx_offset = idx_offset + cur_op_count - pre_op_count + pre_op_count = cur_op_count + idx = idx + idx_offset + 1 + else: + idx += 1 # remove no need vars and ops in the main program remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py b/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py index 2277c69674b3faf4f2fddc43b8032152a465fd42..22692fa5debfccea65da32837819818023eaf80b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py @@ -32,6 +32,7 @@ from paddle.fluid.initializer import NumpyArrayInitializer from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER +from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context paddle.enable_static() _global_parallel_strategy = None @@ -185,6 +186,7 @@ class TestMLPAutoConvert(unittest.TestCase): str(paddle.distributed.get_rank()))) def test_mlp_mp2pp(self): + set_default_distributed_context(None) global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh @@ -211,6 +213,7 @@ class TestMLPAutoConvert(unittest.TestCase): fetch_list=[loss]) last_res = res[0] + set_default_distributed_context(None) _global_parallel_strategy = "pp" _global_process_mesh = auto.ProcessMesh([0, 1]) global PP_MESH_0 @@ -266,6 +269,7 @@ class TestMLPAutoConvert2(unittest.TestCase): str(paddle.distributed.get_rank()))) def test_mlp_pp2mp(self): + set_default_distributed_context(None) global _global_parallel_strategy _global_parallel_strategy = "pp" global _global_process_mesh @@ -302,6 +306,7 @@ class TestMLPAutoConvert2(unittest.TestCase): if paddle.distributed.get_rank() in [1]: last_res = res[0] + set_default_distributed_context(None) _global_parallel_strategy = "mp" _global_process_mesh = auto.ProcessMesh([0, 1]) @@ -345,6 +350,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase): np.random.seed(2021) def test_input_invalid(self): + 
set_default_distributed_context(None) global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh
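For readers following the while-op support above, here is a rough standalone sketch (plain Python; the literal ids, variable names and mesh lists are invented for illustration and are not part of the patch) of the two pieces of bookkeeping this change introduces: the shape of the global `while_block_info` dict that `reshard()` fills in per while op sub block, and the "fully replicated" dims-mapping check that `_is_unshard` and `_is_condition_replicative` rely on.

```python
# Illustration only: mirrors the keys read and written by the patch, not Paddle's real classes.

def is_fully_replicated(dims_mapping):
    # A dims mapping is replicated when no tensor dim is bound to a mesh dim,
    # i.e. every entry is -1. This is what _is_unshard checks inside _need_reshard
    # and what _is_condition_replicative requires for the while op condition var.
    return all(dim == -1 for dim in dims_mapping)


# Hypothetical contents of the global while_block_info, keyed by sub_block id.
while_block_info = {
    1: {
        "op_id": 105,                      # desc id of the while op in the parent block
        "actual_process_mesh": [2, 3],     # stand-in for the ProcessMesh this rank actually runs on
        "var_reshard_mapping": {           # filled by parse_op_desc when a while input is resharded
            "linear_0.tmp_0": "linear_0.tmp_0@RESHARD_0",
        },
    },
}

if __name__ == "__main__":
    assert is_fully_replicated([-1, -1])      # a valid while condition var / cross-mesh tensor
    assert not is_fully_replicated([0, -1])   # sharded along mesh dim 0, must be resharded first
    print(while_block_info[1]["var_reshard_mapping"])
```

In the real code the mesh value is a `ProcessMesh` and the ids come from `op.desc.id()` and `op.attr("sub_block").id`; the sketch only shows which keys the rest of the patch expects to find and update.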