# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License from functools import reduce import paddle import paddle.fluid.layers.utils as utils from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.framework import LayerHelper, OpProtoHolder, Program, core from paddle.utils import unique_name from .cost import ( AllgatherOpCost, CommContext, ConcatOpCost, SendOpCost, SliceOpCost, SplitOpCost, build_comm_desc, ) from .dist_attribute import TensorDistAttr from .dist_context import DistributedContext from .process_group import new_process_group from .utils import is_gradient_clip_op # NOTE: If op in _g_special_ops or _g_gradient_clip_ops, it will not be resharded. _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] _g_gradient_clip_ops = [ "sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div", ] _g_subblock_ops = ["while", "conditional_block"] def get_var_with_recursion(var_name, block, program): """Get var in the parent block if not found in the current block""" var = None if var_name in block.vars: var = block.vars[var_name] else: var = block._var_recursive(var_name) # parent_block = program.blocks[block.parent_idx] # if var_name in parent_block.vars: # var = parent_block.vars[var_name] assert var is not None, "{} is not found".format(var.name) return var class AllGatherOpDesc: """ Describe the allgather op in the reshard phase. Args: group (list): Process group. shape (list): The tensor shape. is_bool (bool): Whether allgather bool data. Default: False. """ def __init__(self, group, shape, is_bool=False): self._group = group self._desc = "all_gather" self._shape = shape self._is_bool = is_bool @property def is_bool(self): return self._is_bool @property def group(self): return self._group @property def desc(self): return self._desc @property def shape(self): return self._shape def __repr__(self): return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}." class AllGatherConcatOpDesc: """ Describe the c_concat op in the reshard phase. Args: group (list): Process group. shape (list): The tensor shape. is_bool (bool): Whether c_concat bool data. Default: False. """ def __init__(self, group, shape, is_bool=False): self._group = group self._desc = "c_concat" self._shape = shape self._is_bool = is_bool @property def is_bool(self): return self._is_bool @property def group(self): return self._group @property def desc(self): return self._desc @property def shape(self): return self._shape def __repr__(self): return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}." class SendOpDesc: """ Describe the send op in the reshard phase. Args: partition_index (list): The index of partition in complete tensor. src (int): The source process to send. dst (int): The destination process to receive. is_bool (bool): Whether send bool data. Default: False. 
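    Example (illustrative values, not taken from a real program):
        SendOpDesc(partition_index=[[0, 4], [0, 8]], src=0, dst=1) describes
        sending the [0:4, 0:8] block of the complete tensor from process 0
        to process 1, so its shape property evaluates to [4, 8].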
""" def __init__(self, partition_index, src, dst, is_bool=False): self._dst = dst self._partition_index = partition_index self._desc = "send" self._shape = [] self._is_bool = is_bool self._src = src @property def src(self): return self._src @property def is_bool(self): return self._is_bool @property def partition_index(self): return self._partition_index @property def dst(self): return self._dst @property def desc(self): return self._desc @property def shape(self): if not self._shape: for item in self.partition_index: self._shape.append(item[1] - item[0]) return self._shape def __repr__(self): return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}." class RecvOpDesc: """ Describe the recv op in the reshard op. Args: partition_index (list): The index of partition in complete tensor. src (int): The source process to send. dst (int): The destination process to receive. is_bool (bool): Whether receive bool data. Default: False. """ def __init__(self, partition_index, src, dst, is_bool=False): self._src = src self._partition_index = partition_index self._desc = "recv" self._shape = [] self._is_bool = is_bool self._dst = dst @property def dst(self): return self._dst @property def is_bool(self): return self._is_bool @property def partition_index(self): return self._partition_index @property def src(self): return self._src @property def desc(self): return self._desc @property def shape(self): if not self._shape: for item in self.partition_index: self._shape.append(item[1] - item[0]) return self._shape def __repr__(self): return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}." class SliceOpDesc: """ Describe the slice op in the reshard phase. Args: starts (list): It represents start indices of corresponding axis in ``axes``. ends (list): It represents end indices of corresponding axis in ``axes``. axes (list): Axes that `starts` and `ends` apply to. shape (list): The shape of the tensor to be sliced. """ def __init__(self, starts, ends, axes, shape=None): self._starts = starts self._ends = ends self._axes = axes self._desc = "slice" self._shape = shape @property def starts(self): return self._starts @property def ends(self): return self._ends @property def axes(self): return self._axes @property def desc(self): return self._desc @property def shape(self): return self._shape def __repr__(self): if self._shape is not None: return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}, shape: {self._shape}." else: return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}." class ConcatOpDesc: """ Describe the concat op in the reshard phase. Args: partition_index_list (list): The list contains all partition index. """ def __init__(self, partition_index_list): self._partition_index_list = partition_index_list self._desc = "concat" @property def partition_index_list(self): return self._partition_index_list @property def desc(self): return self._desc def __repr__(self): return f"op: {self._desc}, partition_index_list: {self._partition_index_list}." 
class Inserter: """Insert op required in the reshard process.""" @staticmethod def insert_cast_op(block, idx, tensor, op_role, tensor_type): # to avoid name conflict with framework new_var_name = paddle.utils.unique_name.generate_with_ignorable_key( ".".join(["cast@RESHARD", 'tmp']) ) out = block.create_var( name=new_var_name, dtype=tensor_type, type=tensor.type, lod_level=tensor.lod_level, ) cast_op = block._insert_op( idx, type='cast', inputs={'X': [tensor]}, outputs={'Out': [out]}, attrs={ 'in_dtype': tensor.dtype, 'out_dtype': out.dtype, 'op_role': op_role, }, ) cast_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod def insert_send_op(block, idx, tensor, src, dst, op_role): """Insert send op into block at the given index.""" op_type = 'send_v2' # use pair comm group process_group = new_process_group([src, dst]) send_op = block._insert_op( idx, type=op_type, inputs={'X': [tensor]}, attrs={ 'ring_id': process_group.id, 'peer': process_group.ranks.index(dst), 'use_calc_stream': True, 'op_role': op_role, 'dynamic_shape': True, }, ) send_op._set_attr('op_namescope', "/auto_parallel/reshard") @staticmethod def insert_recv_op(block, idx, tensor, src, dst, op_role): """Insert recv op into block at the given index.""" op_type = 'recv_v2' # use pair group process_group = new_process_group([src, dst]) recv_op = block._insert_op( idx, type=op_type, inputs={'X': [tensor]}, outputs={'Out': [tensor]}, attrs={ 'ring_id': process_group.id, 'peer': process_group.ranks.index(src), 'out_shape': tensor.shape, 'dtype': tensor.dtype, 'use_calc_stream': True, 'op_role': op_role, 'dynamic_shape': True, }, ) recv_op._set_attr('op_namescope', "/auto_parallel/reshard") @staticmethod def insert_reset_lod_op(block, idx, X, Y, op_role): """Insert reset_lod op into block at the given index.""" new_var_name = paddle.utils.unique_name.generate_with_ignorable_key( ".".join(["reset_lod@RESHARD", 'tmp']) ) reset_lod_out = block.create_var( name=new_var_name, shape=X.shape, type=X.type, dtype=X.dtype, lod_level=X.lod_level, ) reset_op = block._insert_op( idx, type="lod_reset", inputs={'X': X, 'Y': Y}, outputs={'Out': reset_lod_out}, attrs={'op_role': op_role}, ) reset_op._set_attr('op_namescope', "/auto_parallel/reshard") return reset_lod_out @staticmethod def insert_concat_op(block, idx, tensors, axis, op_role): """Insert concat op into block at the given block.""" inputs = {'X': tensors} attrs = {} attrs['axis'] = axis attrs['op_role'] = op_role # to avoid name conflict with framework helper = LayerHelper('concat@RESHARD', **locals()) with paddle.static.program_guard(block.program): out = block.create_var( name=paddle.utils.unique_name.generate_with_ignorable_key( ".".join([helper.name, 'tmp']) ), dtype=tensors[0].dtype, shape=None, lod_level=tensors[0].lod_level, type=tensors[0].type, persistable=False, stop_gradient=False, ) concat_op = block._insert_op( idx, type='concat', inputs=inputs, outputs={'Out': [out]}, attrs=attrs, ) concat_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod def insert_slice_op( block, idx, tensor, starts, ends, axes, new_var_name, op_role ): """Insert slice op into block at the given block.""" # This is a hack to insert split op to get slice tensor # 1. [128, 128] => [64, 128]: split # 2. [128, 128] => [128, 128]: assign # 3. 
[128, 128] => [64, 64]: slice, it will replaced by multi split global_shape = tensor.shape slice_shape = [ends[i] - starts[i] for i in range(len(starts))] diff_dims = [] for index, item in enumerate(slice_shape): if item != global_shape[index]: diff_dims.append(index) # use assign if len(diff_dims) == 0: out = block.create_var( name=new_var_name, dtype=tensor.dtype, type=tensor.type, shape=slice_shape, lod_level=tensor.lod_level, ) inputs = {'X': [tensor]} outputs = {"Out": [out]} attrs = {"in_place": False} slice_op = block._insert_op( idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs ) slice_op._set_attr('op_namescope', "/auto_parallel/reshard") return out # use split once elif len(diff_dims) == 1: diff_dim = diff_dims[0] num_or_sections = global_shape[diff_dim] // slice_shape[diff_dim] axis = diff_dim cur_idx = starts[diff_dim] // slice_shape[diff_dim] input_shape = global_shape inputs = {'X': tensor} attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role} new_shape = [] for index, item in enumerate(tensor.shape): if index != axis: new_shape.append(item) else: new_shape.append(item // num_or_sections) with paddle.static.program_guard(block.program): outs = [ block.create_var( name=paddle.utils.unique_name.generate_with_ignorable_key( ".".join(['split@RESHARD', 'tmp']) ), dtype=tensor.dtype, shape=None, type=tensor.type, persistable=False, lod_level=tensor.lod_level, stop_gradient=False, ) for i in range(num_or_sections) ] out = outs[cur_idx] split_op = block._insert_op( idx, type="split", inputs=inputs, outputs={'Out': outs}, attrs=attrs, ) split_op._set_attr('op_namescope', "/auto_parallel/reshard") return out # use slice else: inputs = {'Input': tensor} infer_flags = list(1 for i in range(len(axes))) attrs = { "axes": axes, "starts": starts, "ends": ends, "infer_flags": infer_flags, 'op_role': op_role, } out = block.create_var( name=new_var_name, dtype=tensor.dtype, type=tensor.type, lod_level=tensor.lod_level, ) slice_op = block._insert_op( idx, type="slice", inputs=inputs, outputs={'Out': [out]}, attrs=attrs, ) slice_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod def insert_split_op(block, idx, tensor, num_or_sections, op_role, axis=0): """Insert split op into block at the given index.""" helper = LayerHelper('split@RESHARD', **locals()) input_shape = tensor.shape inputs = {'X': tensor} attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role} new_shape = [] for index, item in enumerate(tensor.shape): if index != axis: new_shape.append(item) else: new_shape.append(item // num_or_sections) with paddle.static.program_guard(block.program): outs = [ block.create_var( name=paddle.utils.unique_name.generate_with_ignorable_key( ".".join([helper.name, 'tmp']) ), dtype=tensor.dtype, shape=None, lod_level=tensor.lod_level, type=tensor.type, persistable=False, stop_gradient=False, ) for i in range(num_or_sections) ] split_op = block._insert_op( idx, type="split", inputs=inputs, outputs={'Out': outs}, attrs=attrs ) split_op._set_attr('op_namescope', "/auto_parallel/reshard") return outs @staticmethod def insert_fill_constant_op(block, idx, op_role): """Insert fill constant op into block at the given index.""" # to avoid name conflict with framework helper = LayerHelper('fill_constant@RESHARD', **locals()) # use paddle.int64 as dtype with paddle.static.program_guard(block.program): out = block.create_var( name=paddle.utils.unique_name.generate_with_ignorable_key( ".".join([helper.name, 'tmp']) ), dtype=paddle.int64, shape=None, 
type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=False, ) inputs = {} attrs = {'force_cpu': False} attrs['str_value'] = str(int("1")) attrs['value'] = int("1") attrs['dtype'] = out.dtype attrs['op_role'] = op_role utils.get_shape_tensor_inputs( inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant' ) fillconstant_op = block._insert_op( idx, type='fill_constant', inputs=inputs, outputs={'Out': [out]}, attrs=attrs, ) out.stop_gradient = True fillconstant_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod def insert_allgather_op(block, idx, tensor, ranks, op_role): """Insert allgather op into block at the given index.""" tensor_list = [] group = new_process_group(ranks) idx_offset = 0 # instant process group before insert allgather op. if not group.is_instantiate(): # insert fill_constant op fill_constant_out = Inserter.insert_fill_constant_op( block, idx, op_role ) fill_constant_out.stop_gradient = True # insert c_allreduce_sum op allreduce_op = block._insert_op( idx + 1, type="c_allreduce_sum", inputs={'X': [fill_constant_out]}, outputs={'Out': [fill_constant_out]}, attrs={ 'ring_id': 0, 'use_calc_stream': True, 'op_role': op_role, }, ) allreduce_op._set_attr('op_namescope', "/auto_parallel/reshard") # insert c_sync_calc_stream op sync_calc_op = block._insert_op( idx + 2, type="c_sync_calc_stream", inputs={'X': [fill_constant_out]}, outputs={'Out': [fill_constant_out]}, attrs={'op_role': op_role}, ) sync_calc_op._set_attr('op_namescope', "/auto_parallel/reshard") idx_offset = 3 # insert c_allgather op op_type = 'c_allgather' # to avoid name conflict with framework helper = LayerHelper(op_type + "@RESHARD", **locals()) with paddle.static.program_guard(block.program): allgather_out = block.create_var( name=paddle.utils.unique_name.generate_with_ignorable_key( ".".join([helper.name, 'tmp']) ), dtype=tensor.dtype, shape=None, lod_level=tensor.lod_level, type=tensor.type, persistable=False, stop_gradient=False, ) allgather_op = block._insert_op( idx + idx_offset, type=op_type, inputs={'X': [tensor]}, outputs={'Out': [allgather_out]}, attrs={ 'ring_id': group.id, 'use_calc_stream': True, 'nranks': group.nranks, 'op_role': op_role, }, ) allgather_op._set_attr('op_namescope', "/auto_parallel/reshard") idx_offset += 1 # insert split op split_out = Inserter.insert_split_op( block, idx + idx_offset, allgather_out, group.nranks, op_role ) idx_offset += 1 tensor_list.extend(split_out) return tensor_list, idx_offset @staticmethod def insert_c_concat_op(block, idx, tensor, ranks, op_role): """Insert c_concat op into block at the given index.""" group = new_process_group(ranks) idx_offset = 0 # insert c_concat op op_type = 'c_concat' # to avoid name conflict with framework helper = LayerHelper(op_type + "@RESHARD", **locals()) with paddle.static.program_guard(block.program): c_concat_out = block.create_var( name=paddle.utils.unique_name.generate_with_ignorable_key( ".".join([helper.name, 'tmp']) ), dtype=tensor.dtype, shape=None, lod_level=tensor.lod_level, type=tensor.type, persistable=False, stop_gradient=False, ) cur_rank = paddle.distributed.get_rank() c_concat_op = block._insert_op( idx + idx_offset, type=op_type, inputs={'X': [tensor]}, outputs={'Out': [c_concat_out]}, attrs={ 'ring_id': group.id, 'use_calc_stream': True, 'use_model_parallel': True, 'nranks': group.nranks, 'op_role': op_role, 'rank': group.ranks.index(cur_rank) if cur_rank in ranks else 0, }, ) c_concat_op._set_attr('op_namescope', "/auto_parallel/reshard") return 
c_concat_out @staticmethod def concat_partitions_with_op( partition_tensor_list, tensor, partition_index, block, idx, op_role ): """Concat the tensors and insert concat op.""" if not partition_tensor_list: partition_tensor_list.append((tensor, partition_index)) else: i = 0 has_concat = False while i < len(partition_tensor_list): ( concat_axis, first_order, new_partition, ) = Resharder.compute_concat_info( partition_tensor_list[i][1], partition_index ) if concat_axis != -1: has_concat = True _ = ( Inserter.insert_concat_op( block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role, ) if first_order == 0 else Inserter.insert_concat_op( block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role, ) ) partition_tensor_list.pop(i) idx[0] += 1 Inserter.concat_partitions_with_op( partition_tensor_list, _, new_partition, block, idx, op_role, ) break i += 1 if not has_concat: partition_tensor_list.append((tensor, partition_index)) class Remover: """Remove var and op in the reshard process.""" @staticmethod def remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): """Remove no need ops in the main program""" not_remove_op_ref = [ "create_py_reader", "create_double_buffer_reader", "read", ] # NOTE: The nested sub block is not be supported now. remove_block_order = [] for block_idx in Resharder.while_block_info: remove_block_order.append(block_idx) for block_idx, block in enumerate(auto_parallel_main_prog.blocks): if block_idx not in remove_block_order: remove_block_order.append(block_idx) # the sub block should be removed first for block_idx in remove_block_order: remove_op_idx = [] block = auto_parallel_main_prog.blocks[block_idx] ops = block.ops vars = block.vars for idx, op in enumerate(ops): if op.type == "read": dim_list = [] for var_name in op.output_arg_names: dim_list.extend( get_var_with_recursion( var_name, block, auto_parallel_main_prog ).shape ) for i in range(idx, -1, -1): if ops[i].type == "create_py_reader": ops[i]._set_attr("shape_concat", dim_list) break continue # replace the input and output of c_sync_comm_stream op when in pipeline scene. if op.type == "c_sync_comm_stream": need_save = [] for var_name in op.input_arg_names: process_mesh = ( dist_context.get_tensor_dist_attr_for_program( get_var_with_recursion( var_name, block, auto_parallel_main_prog ) ).process_mesh ) if rank_id in process_mesh.process_ids: need_save.append(var_name) if not need_save: remove_op_idx.append(idx) continue proto = OpProtoHolder.instance().get_op_proto(op.type) op.desc.set_input(proto.inputs[0].name, need_save) op.desc.set_output(proto.outputs[0].name, need_save) continue # judge the other op whether should be removed. 
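                # For example (hypothetical ranks): when rank_id is 0 and an op's
                # dist attr places it on process_mesh [1, 2], the op is removed from
                # this rank's program unless its type is in not_remove_op_ref.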
op_dist_attr = dist_context.get_op_dist_attr_for_program(op) if op_dist_attr is not None: op_process_mesh = op_dist_attr.process_mesh if ( rank_id not in op_process_mesh.process_ids and op.type not in not_remove_op_ref ): remove_op_idx.append(idx) for idx in remove_op_idx[::-1]: block._remove_op(idx) @staticmethod def remove_no_need_vars( auto_parallel_main_prog, dist_params_grads, feed_var_names ): """Remove no need vars in the main program""" for block_idx, block in enumerate(auto_parallel_main_prog.blocks): remove_vars = set() ops = block.ops vars = block.vars need_vars = set() for op in ops: for var_name in op.input_arg_names: if var_name in vars: need_vars.add(var_name) for var_name in op.output_arg_names: if var_name in vars: need_vars.add(var_name) for var in vars: if var not in need_vars: remove_vars.add(var) # change dist_params_grads, the optimize op just in block 0. if block_idx == 0: param_grad_map = {} for op in ops: if int(op.attr('op_role')) == int(OpRole.Optimize): if ( "Param" in op.input_names and "Grad" in op.input_names ): param_name = op.input("Param")[0] grad_name = op.input("Grad")[0] param_grad_map[param_name] = grad_name need_remove_idx = [] for idx, item in enumerate(dist_params_grads): if item[0].name not in param_grad_map.keys(): need_remove_idx.append(idx) for idx in need_remove_idx[::-1]: dist_params_grads.pop(idx) idx = 0 while idx < len(dist_params_grads): param_name = dist_params_grads[idx][0].name grad_name = dist_params_grads[idx][1].name if grad_name != param_grad_map[param_name]: dist_params_grads[idx] = ( vars[param_name], vars[param_grad_map[param_name]], ) idx += 1 for var in remove_vars: if var in feed_var_names: continue block._remove_var(var) @staticmethod def remove_no_need_in_main( auto_parallel_main_prog, dist_context, rank_id, dist_params_grads ): """Remove no need vars and ops in the main program.""" Remover.remove_no_need_ops( auto_parallel_main_prog, dist_context, rank_id ) Resharder.change_while_op_input_and_output( auto_parallel_main_prog, dist_context ) # 'feed_var_names' cannot be removed from auto_parallel_main_prog feed_var_names = [] for var in sum(list(dist_context.serial_feed_vars.values()), []): feed_var_names.append(var.name) Remover.remove_no_need_vars( auto_parallel_main_prog, dist_params_grads, feed_var_names ) @staticmethod def remove_no_need_in_startup( auto_parallel_main_prog, auto_parallel_startup_prog ): """Remove no need vars and ops in the startup program.""" main_input_vars = set() main_ops = auto_parallel_main_prog.global_block().ops for op in main_ops: for var_name in op.input_arg_names: main_input_vars.add(var_name) startup_block = auto_parallel_startup_prog.global_block() startup_output_vars = set() startup_ops = startup_block.ops for op in startup_ops: # skip c_sync_comm_stream op if op.type == "c_sync_comm_stream": continue for var_name in op.output_arg_names: startup_output_vars.add(var_name) need_vars = set() for var_name in startup_output_vars: if var_name in main_input_vars: need_vars.add(var_name) startup_ops = startup_block.ops actual_need_vars = set() for idx, op in enumerate(startup_ops): is_need_op = False if op.type == "c_sync_comm_stream": continue for var_name in op.output_arg_names: if var_name in need_vars: is_need_op = True break if is_need_op: for var_name in op.output_arg_names: actual_need_vars.add(var_name) for var_name in op.input_arg_names: actual_need_vars.add(var_name) remove_vars = set() for var_name in startup_block.vars: if var_name not in actual_need_vars: remove_vars.add(var_name) for 
var in remove_vars: startup_block._remove_var(var) remove_op_idx = [] vars = startup_block.vars for idx, op in enumerate(startup_block.ops): is_no_need_op = False if op.type == "c_sync_comm_stream": var_names = [] for var_name in op.input_arg_names: if var_name in vars: var_names.append(var_name) if not var_names: remove_op_idx.append(idx) else: proto = OpProtoHolder.instance().get_op_proto(op.type) op.desc.set_input(proto.inputs[0].name, var_names) op.desc.set_output(proto.outputs[0].name, var_names) continue for var_name in op.output_arg_names: if var_name not in vars: is_no_need_op = True break if is_no_need_op: remove_op_idx.append(idx) for idx in remove_op_idx[::-1]: startup_block._remove_op(idx) class Resharder: """ Reshard tensor in the program according to its distributed attribute and corresponding op distributed attribute. Args: auto_parallel_main_prog (Program): An auto parallel main program. auto_parallel_startup_prog (Program): An auto parallel startup program. rank_id (int): The process id. dist_context (DistributedContext): The distributed context of this rank. dist_params_grads (list): The list contains the tuple of param and grad. batch_size (int): The batch size. Default: None. """ while_block_info = {} def __init__( self, auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, dist_context, dist_params_grads, batch_size=None, ): assert isinstance(auto_parallel_main_prog, Program), ( "The type of auto_parallel_main_prog should be Program, " "but got {}.".format(type(auto_parallel_main_prog)) ) if auto_parallel_startup_prog is not None: assert isinstance(auto_parallel_main_prog, Program), ( "The type of auto_parallel_startup_prog should be Program or None, " "but got {}.".format(type(auto_parallel_startup_prog)) ) assert isinstance( rank_id, int ), "The type of rank_id should be int, " "but got {}.".format( type(rank_id) ) assert isinstance(dist_context, DistributedContext), ( "The type of dist_context should be DistributedContext, " "but got {}.".format(type(dist_context)) ) if batch_size is not None: assert isinstance( batch_size, int ), "The type of batch_size should be int, " "but got {}.".format( type(batch_size) ) self._auto_parallel_main_prog = auto_parallel_main_prog self._auto_parallel_startup_prog = auto_parallel_startup_prog self._rank_id = rank_id self._dist_context = dist_context self._dist_params_grads = dist_params_grads self._batch_size = batch_size self._has_sent = {} self._has_recv = {} self._has_allgather = {} # to avoid reshard repeatly self._has_resharded = {} @property def auto_parallel_main_prog(self): return self._auto_parallel_main_prog @property def auto_parallel_startup_prog(self): return self._auto_parallel_startup_prog @property def rank_id(self): return self._rank_id @property def dist_context(self): return self._dist_context @property def dist_params_grads(self): return self._dist_params_grads @property def batch_size(self): return self._batch_size @property def has_sent(self): return self._has_sent @property def has_recv(self): return self._has_recv @property def has_allgather(self): return self._has_allgather @staticmethod def compute_partition_shape(complete_shape, dims_mapping, process_shape): """Compute the shape of partition.""" partition_shape = [] for idx, item in enumerate(complete_shape): if dims_mapping[idx] == -1: partition_shape.append(item) else: partition_shape.append(item // process_shape[dims_mapping[idx]]) return partition_shape @staticmethod def compute_process_index(process, process_group, process_shape): """Compute 
the index of process_shape corresponding to the process.""" relative_process = process_group.index(process) process_index = [] product = reduce(lambda x, y: x * y, process_shape) for i in range(len(process_shape)): idx = relative_process // (product // process_shape[i]) product = product // process_shape[i] relative_process = ( relative_process - relative_process // product * product ) process_index.append(idx) return process_index @staticmethod def compute_partition_index( process, complete_shape, dims_mapping, process_shape, process_group ): """Compute the partition index in complete tensor.""" partition_shape = Resharder.compute_partition_shape( complete_shape, dims_mapping, process_shape ) process_index = Resharder.compute_process_index( process, process_group, process_shape ) partition_index = [] for i in range(len(complete_shape)): if dims_mapping[i] == -1: partition_index.append([0, partition_shape[i]]) else: partition_index.append( [ process_index[dims_mapping[i]] * partition_shape[i], (process_index[dims_mapping[i]] + 1) * partition_shape[i], ] ) return partition_index @staticmethod def compute_concat_info(partition_index_x, partition_index_y): """Judge whether two partition can be concatenated and compute concatenated partition index.""" differ_count = 0 concat_axis = -1 first_order = 0 new_partition = [] for idx, item in enumerate(partition_index_x): if item != partition_index_y[idx]: differ_count += 1 if ( item[1] == partition_index_y[idx][0] and item[0] < partition_index_y[idx][1] ): concat_axis = idx new_partition.append([item[0], partition_index_y[idx][1]]) elif ( item[0] == partition_index_y[idx][1] and item[1] > partition_index_y[idx][0] ): first_order = 1 concat_axis = idx new_partition.append([partition_index_y[idx][0], item[1]]) else: new_partition.append(item) if differ_count == 1: return concat_axis, first_order, new_partition else: return -1, first_order, new_partition @staticmethod def compute_complete_shape(slice_shape, process_shape, dims_mapping): """compute the complete shape of the slice tensor with its process mesh and dims mapping""" complete_shape = [] for idx, item in enumerate(slice_shape): if dims_mapping[idx] == -1: complete_shape.append(item) else: complete_shape.append(item * process_shape[dims_mapping[idx]]) return complete_shape @staticmethod def concat_partitions(partition_index_list, partition_index): """Concat the given partitions without inserting concat op.""" if not partition_index_list: partition_index_list.append(partition_index) else: i = 0 has_concat = False while i < len(partition_index_list): concat_axis, _, new_partition = Resharder.compute_concat_info( partition_index_list[i], partition_index ) if concat_axis != -1: has_concat = True partition_index_list.pop(i) Resharder.concat_partitions( partition_index_list, new_partition ) break i += 1 if not has_concat: partition_index_list.append(partition_index) @staticmethod def change_while_op_input_and_output(auto_parallel_main_prog, dist_context): """Change while op input and output after the corresponding sub block ops removed""" for sub_block_idx in Resharder.while_block_info: sub_block = auto_parallel_main_prog.blocks[sub_block_idx] parent_while_op_id = Resharder.while_block_info[sub_block_idx][ "op_id" ] parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx] sub_block_op_inputs = set() sub_block_op_outputs = [] for op in sub_block.ops: # skip the input and output of operators inserted in the reshard phase dist_op = dist_context.get_dist_op_for_program(op) if ( dist_op or 
(op.type == "slice" and not dist_op) or (op.type == "split" and not dist_op) or (op.type == "assign" and not dist_op) ): for var_name in op.output_arg_names: if var_name not in sub_block_op_outputs: sub_block_op_outputs.append(var_name) for var_name in op.input_arg_names: sub_block_op_inputs.add(var_name) # find the while op while_op = None for op in parent_block.ops: if op.desc.id() == parent_while_op_id and op.type == "while": while_op = op break if while_op is None: continue # find the actual input and output of while op proto = OpProtoHolder.instance().get_op_proto(while_op.type) new_X = [] for var_name in while_op.input("X"): if var_name in sub_block_op_inputs: new_X.append(var_name) assert new_X new_X.sort() while_op.desc.set_input(proto.inputs[0].name, new_X) new_Out = [] for var_name in while_op.output("Out"): for output_name in sub_block_op_outputs[::-1]: if output_name.find(var_name) != -1 and ( len(var_name) == len(output_name) or "@RESHARD" in output_name ): if output_name not in new_Out: new_Out.append(output_name) assert new_Out while_op.desc.set_output(proto.outputs[0].name, new_Out) def is_overlapped(self, shape_x, shape_y): """Judge whether two partitions intersect on the specified dimension.""" overlapped = False if (shape_y[0] <= shape_x[0] < shape_y[1]) or ( shape_x[0] <= shape_y[0] < shape_x[1] ): overlapped = True return overlapped def is_unshard(self, dims_mapping): for dim in dims_mapping: if dim != -1: return False return True def is_special_op(self, op): global _g_special_ops, _g_gradient_clip_ops if op.type in _g_special_ops: return True if is_gradient_clip_op(op) and op.type in _g_gradient_clip_ops: return True return False def is_condition_replicative(self, op): sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] if op.type == "while": input_cond = op.input("Condition") elif op.type == "conditional_block": input_cond = op.input("Cond") # the dims mapping of condition tensor should be replicative for var_name in input_cond: var = get_var_with_recursion( var_name, sub_block, self.auto_parallel_main_prog ) dist_tensor = self.dist_context.get_dist_tensor_for_program(var) tensor_dist_attr = dist_tensor.dist_attr var_dims_mapping = tensor_dist_attr.dims_mapping for dim in var_dims_mapping: if dim != -1: return False return True def need_reshard(self, dist_tensor, dist_attr, op_input=True, dist_op=None): """Judge the tensor whether needs to be resharded.""" is_reshard = False tensor_dist_attr = dist_tensor.dist_attr tensor_dims_mapping = tensor_dist_attr.dims_mapping tensor_process_mesh = tensor_dist_attr.process_mesh # dist_attr is [process_mesh, dims_mapping] and process_mesh is not a union op_process_mesh = dist_attr[0] if op_input: op_input_dims_mapping = dist_attr[1] if all( map( lambda x: x, [ tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping, op_process_mesh, ], ) ): # judge whether need reshard by dims_mapping if tensor_dims_mapping != op_input_dims_mapping: if ( tensor_process_mesh not in self.dist_context.process_meshes ): # assert whether -1 when union. for item in tensor_dims_mapping: if item != -1: raise ValueError( "The dim must be -1 when tensor process mesh is a union." ) # tensor process_mesh: [0, 1, 2, 3], dims_mapping: [-1, -1] # op process_mesh: [4, 5], dims_mapping: [0, -1] # reshard is not supported such as above if not is_reshard: return is_reshard else: raise ValueError( "it is not supported that tensor process mesh is a union and needs reshard." 
) is_reshard = True # judge whether need reshard by process_mesh if tensor_process_mesh != op_process_mesh: is_reshard = True else: op_output_dims_mapping = dist_attr[1] if all( map( lambda x: x, [ tensor_dims_mapping, tensor_process_mesh, op_output_dims_mapping, op_process_mesh, ], ) ): if tensor_dims_mapping != op_output_dims_mapping: raise ValueError( "It is not supported that tensor dims mapping is different from op output dims mapping." ) if tensor_process_mesh != op_process_mesh: is_reshard = True return is_reshard def get_op_process_meshes(self, op): """Get sub process meshes of the given op if op process mesh is a union.""" process_meshes = [] dist_op = self.dist_context.get_dist_op_for_program(op) op_process_mesh = dist_op.dist_attr.process_mesh for process_mesh in self.dist_context.process_meshes: if set(process_mesh.process_ids) & ( set(op_process_mesh.process_ids) ) and len(process_mesh.process_ids) < len( op_process_mesh.process_ids ): process_meshes.append(process_mesh) # it means the process mesh is not a union when process meshes is null if not process_meshes: process_meshes.append(op_process_mesh) return process_meshes def find_op_desc_seq(self, dist_tensor, dist_attr, serial=False): """ Find the op description sequence to reshard the source tensor for matching the op requirement. Args: dist_tensor (DistributedTensor): A distributed tensor. dist_attr (list): A list contains process_mesh and dims_mapping such as [process_mesh, dims_mapping]. serial (bool): If serial is true, the dist tensor and dist op come from serial program. Otherwise, they come from auto program. Returns: Dict, the dict represents the required op description sequence corresponding to process, The key of dict is process and value is a list containing op description. 
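            For example (hypothetical ranks): resharding between two disjoint
            one-process meshes may produce
            {0: [SendOpDesc(...)], 1: [RecvOpDesc(...), ConcatOpDesc(...), SliceOpDesc(...)]}.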
""" tensor_dist_attr = dist_tensor.dist_attr source_tensor = dist_tensor.serial_tensor tensor_name = source_tensor.name source_dims_mapping = tensor_dist_attr.dims_mapping source_process_mesh = tensor_dist_attr.process_mesh source_process_group = source_process_mesh.process_ids source_process_shape = source_process_mesh.shape target_process_mesh = dist_attr[0] target_dims_mapping = dist_attr[1] target_process_group = target_process_mesh.process_ids target_process_shape = target_process_mesh.shape if source_tensor.shape[0] < 0: assert source_tensor.shape[0] == -1 new_shape = list(source_tensor.shape) new_shape[0] = self.batch_size source_tensor.desc.set_shape(new_shape) complete_shape = ( Resharder.compute_complete_shape( source_tensor.shape, source_process_shape, source_dims_mapping ) if not serial else source_tensor.shape ) op_desc_seq = {} # TODO: if the target process group has the same process with source process group if set(target_process_group).intersection( set(source_process_group) ) and set(target_process_group).difference(set(source_process_group)): pass elif target_process_group != source_process_group: partition_process_mapping_list = [] for source_process in source_process_group: # get partition index of source process source_partition_index = Resharder.compute_partition_index( source_process, complete_shape, source_dims_mapping, source_process_shape, source_process_group, ) if not partition_process_mapping_list: # the item in partition_process_mapping_list is source_partition_index, which processes and whether has been used partition_process_mapping_list.append( [source_partition_index, [source_process], [False]] ) else: partition_list = list( [item[0] for item in partition_process_mapping_list] ) process_list = list( [item[1] for item in partition_process_mapping_list] ) has_used = list( [item[2] for item in partition_process_mapping_list] ) if partition_list.count(source_partition_index) == 1: index = partition_list.index(source_partition_index) process_list[index].append(source_process) has_used[index].append(False) else: partition_process_mapping_list.append( [source_partition_index, [source_process], [False]] ) for target_process in target_process_group: # has_sent means the source_partition_index has been sent to target_process has_sent = [] target_partition_index = Resharder.compute_partition_index( target_process, complete_shape, target_dims_mapping, target_process_shape, target_process_group, ) partition_index_list = [] all_partition_index_list = [] for source_process in source_process_group: source_partition_index = Resharder.compute_partition_index( source_process, complete_shape, source_dims_mapping, source_process_shape, source_process_group, ) to_send_process = None if ( all( _ for _ in list( map( self.is_overlapped, source_partition_index, target_partition_index, ) ) ) and source_partition_index not in has_sent ): idx = list( [item[0] for item in partition_process_mapping_list] ).index(source_partition_index) has_used = list( [item[2] for item in partition_process_mapping_list] )[idx] process_list = list( [item[1] for item in partition_process_mapping_list] )[idx] i = 0 while i < len(has_used): if not has_used[i]: to_send_process = process_list[i] has_used[i] = True break i += 1 if i == len(has_used): has_used = list(map(lambda x: False, has_used)) to_send_process = process_list[0] has_used[0] = True assert ( to_send_process is not None ), "Failed to find the send process." 
if to_send_process not in op_desc_seq.keys(): op_desc_seq[to_send_process] = [] if target_process not in op_desc_seq.keys(): op_desc_seq[target_process] = [] all_partition_index_list.append(source_partition_index) # append send and recv op desc is_bool = dist_tensor.serial_tensor.dtype == paddle.bool send_op_desc = SendOpDesc( source_partition_index, to_send_process, target_process, is_bool=is_bool, ) recv_op_desc = RecvOpDesc( source_partition_index, to_send_process, target_process, is_bool=is_bool, ) op_desc_seq[to_send_process].append(send_op_desc) op_desc_seq[target_process].append(recv_op_desc) has_sent.append(source_partition_index) Resharder.concat_partitions( partition_index_list, source_partition_index ) # append concat op desc op_desc_seq[target_process].append( ConcatOpDesc(all_partition_index_list) ) # append slice op desc slice_starts = [] slice_ends = [] slices_axes = [] concatenated_partition_index = partition_index_list[0] to_slice_tensor_shape = [] for idx, item in enumerate(concatenated_partition_index): slice_starts.append( target_partition_index[idx][0] - item[0] ) slice_ends.append(target_partition_index[idx][1] - item[0]) slices_axes.append(idx) to_slice_tensor_shape.append(item[1] - item[0]) op_desc_seq[target_process].append( SliceOpDesc( slice_starts, slice_ends, slices_axes, shape=to_slice_tensor_shape, ) ) # In the same process group, it will use allgahther and slice op. else: # NOTE: It just supports even partition scene. partition_index_list = [] all_partition_index_list = [] process_index = [] for source_process in source_process_group: source_partition_index = Resharder.compute_partition_index( source_process, complete_shape, source_dims_mapping, source_process_shape, source_process_group, ) if source_partition_index not in partition_index_list: partition_index_list.append(source_partition_index) process_index.append( [ [ source_process, ], source_partition_index, ] ) else: process_index[ partition_index_list.index(source_partition_index) ][0].append(source_process) for i in range(len(process_index[0][0])): group = [] for j in range(len(process_index)): group.append(process_index[j][0][i]) if i == 0: all_partition_index_list.append(process_index[j][1]) for process in group: # append slice op desc slice_starts = [] slice_ends = [] slices_axes = [] target_partition_index = Resharder.compute_partition_index( process, complete_shape, target_dims_mapping, target_process_shape, target_process_group, ) for idx, item in enumerate(target_partition_index): slice_starts.append(item[0]) slice_ends.append(item[1]) slices_axes.append(idx) to_slice_tensor_shape = dist_tensor.global_sizes() slice_op_desc = SliceOpDesc( starts=slice_starts, ends=slice_ends, axes=slices_axes, shape=to_slice_tensor_shape, ) allgather_shape = ( None if not serial else dist_tensor.local_sizes(rank=process) ) # c_concat pass if ( target_dims_mapping.count(-1) == len(target_dims_mapping) and source_dims_mapping[:-1].count(-1) == len(source_dims_mapping[:-1]) and source_dims_mapping[-1] != -1 ): op_desc_seq[process] = [ AllGatherConcatOpDesc( group=group, shape=allgather_shape ) ] else: op_desc_seq[process] = ( [ AllGatherOpDesc( group=group, shape=allgather_shape, is_bool=( source_tensor.dtype == paddle.bool ), ), ConcatOpDesc( partition_index_list=all_partition_index_list ), slice_op_desc, ] if len(group) > 1 else [slice_op_desc] ) return op_desc_seq def parse_op_desc( self, block, op_desc_seq, var_name, reshard_op, dist_attr ): """Parse op desc sequence and insert op in the block""" tensor_list = 
[] partition_tensor_list = [] if self.rank_id not in op_desc_seq.keys(): return op_desc_list = op_desc_seq[self.rank_id] idx = None for index, op in list(enumerate(block.ops)): if op.desc.id == reshard_op.desc.id: idx = index break assert ( idx is not None ), "The op for reshard cannot be found in the rank {} program.".format( self.rank_id ) matched_op = block.ops[idx] source_tensor = get_var_with_recursion( var_name, block, self.auto_parallel_main_prog ) for op_desc in op_desc_list: if isinstance(op_desc, AllGatherOpDesc): # noqa: F401 if var_name not in self.has_allgather.keys(): self.has_allgather[var_name] = [] if not self.has_allgather[ var_name ] or op_desc.group not in list( map(lambda x: x[0], self.has_allgather[var_name]) ): if op_desc.is_bool: # for bool data allgather, cast to int64 -> allgather -> cast bool out_cast = Inserter.insert_cast_op( block, idx, source_tensor, reshard_op.attr('op_role'), paddle.int64, ) tensor_list, idx_offset = Inserter.insert_allgather_op( block, idx + 1, out_cast, op_desc.group, reshard_op.attr('op_role'), ) idx += idx_offset tensor_name_list = [] for var in tensor_list: out_cast = Inserter.insert_cast_op( block, idx, var, reshard_op.attr('op_role'), paddle.bool, ) tensor_name_list.append(out_cast.name) idx += 1 self.has_allgather[var_name].append( [op_desc.group, tensor_name_list] ) else: tensor_list, idx_offset = Inserter.insert_allgather_op( block, idx, source_tensor, op_desc.group, reshard_op.attr('op_role'), ) idx += idx_offset tensor_name_list = [var.name for var in tensor_list] self.has_allgather[var_name].append( [op_desc.group, tensor_name_list] ) else: for item in self.has_allgather[var_name]: if op_desc.group == item[0]: tensor_list = [ get_var_with_recursion( var_name, block, self.auto_parallel_main_prog, ) for var_name in item[1] ] break assert ( tensor_list ), "The result of parsing allgather op should not be None." 
elif isinstance(op_desc, SendOpDesc): if var_name not in self.has_sent.keys(): self.has_sent[var_name] = [] if op_desc.dst not in self.has_sent[var_name]: if op_desc.is_bool: out_cast = Inserter.insert_cast_op( block, idx, source_tensor, reshard_op.attr('op_role'), paddle.int64, ) Inserter.insert_send_op( block, idx + 1, out_cast, op_desc.src, op_desc.dst, reshard_op.attr('op_role'), ) idx += 2 else: Inserter.insert_send_op( block, idx, source_tensor, op_desc.src, op_desc.dst, reshard_op.attr('op_role'), ) idx += 1 self.has_sent[var_name].append(op_desc.dst) elif isinstance(op_desc, RecvOpDesc): if var_name not in self.has_recv.keys(): self.has_recv[var_name] = {} if op_desc.src not in self.has_recv[var_name].keys(): partition_index = op_desc.partition_index shape = [] for index in partition_index: shape.append(index[1] - index[0]) if op_desc.is_bool: # for bool data, recv int64 -> cast to bool recv_tensor = block.create_var( name=unique_name.generate(var_name + "@recv"), shape=shape, lod_level=source_tensor.lod_level, dtype=paddle.int64, type=source_tensor.type, ) Inserter.insert_recv_op( block, idx, recv_tensor, op_desc.src, op_desc.dst, reshard_op.attr('op_role'), ) out_cast = Inserter.insert_cast_op( block, idx + 1, recv_tensor, reshard_op.attr('op_role'), paddle.bool, ) tensor_list.append(out_cast) idx += 2 self.has_recv[var_name][op_desc.src] = out_cast else: recv_tensor = block.create_var( name=unique_name.generate(var_name + "@recv"), shape=shape, lod_level=source_tensor.lod_level, dtype=source_tensor.dtype, type=source_tensor.type, ) Inserter.insert_recv_op( block, idx, recv_tensor, op_desc.src, op_desc.dst, reshard_op.attr('op_role'), ) # for lod tensor, need reset lod after received if recv_tensor.lod_level != 0: set_lod = False # use data lod to reset tensor lod for ( tmp_block ) in self.auto_parallel_main_prog.blocks: for tmp_var_name in tmp_block.vars: tmp_var = tmp_block.vars[tmp_var_name] if ( tmp_var.is_data and tmp_var.lod_level == recv_tensor.lod_level ): reset_lod_out = ( Inserter.insert_reset_lod_op( block, idx + 1, recv_tensor, tmp_var, reshard_op.attr('op_role'), ) ) tensor_list.append(reset_lod_out) idx += 2 self.has_recv[var_name][ op_desc.src ] = reset_lod_out set_lod = True break if set_lod: break assert set_lod is True else: tensor_list.append(recv_tensor) idx += 1 self.has_recv[var_name][op_desc.src] = recv_tensor else: tensor_list.append(self.has_recv[var_name][op_desc.src]) elif isinstance(op_desc, ConcatOpDesc): partition_index_list = op_desc.partition_index_list idx_list = [idx] for index, tensor in enumerate(tensor_list): Inserter.concat_partitions_with_op( partition_tensor_list, tensor, partition_index_list[index], block, idx_list, reshard_op.attr('op_role'), ) idx = idx_list[0] elif isinstance(op_desc, SliceOpDesc) or isinstance( op_desc, AllGatherConcatOpDesc ): target_tensor = None if isinstance(op_desc, SliceOpDesc): assert ( len(partition_tensor_list) == 1 or not partition_tensor_list ) to_slice_tensor = ( partition_tensor_list[0][0] if len(partition_tensor_list) == 1 else source_tensor ) new_name = unique_name.generate(var_name + "@RESHARD") target_tensor = Inserter.insert_slice_op( block, idx, to_slice_tensor, starts=op_desc.starts, ends=op_desc.ends, axes=op_desc.axes, new_var_name=new_name, op_role=reshard_op.attr('op_role'), ) else: target_tensor = Inserter.insert_c_concat_op( block, idx, source_tensor, op_desc.group, reshard_op.attr('op_role'), ) assert target_tensor is not None process_mesh = dist_attr[0] dims_mapping = dist_attr[1] 
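                # The resharded output is registered below with the requested
                # [process_mesh, dims_mapping] so that consumer ops whose input
                # dist attr matches it can be renamed to the new tensor.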
tensor_attr = TensorDistAttr() tensor_attr.dims_mapping = dims_mapping tensor_attr.process_mesh = process_mesh self.dist_context.set_tensor_dist_attr_for_program( target_tensor, tensor_attr ) if matched_op.type == "while": # var_reshard_mapping means the while op input need be changed to if ( "var_reshard_mapping" not in Resharder.while_block_info[ op.attr("sub_block").id ].keys() ): Resharder.while_block_info[op.attr("sub_block").id][ "var_reshard_mapping" ] = {} if ( var_name not in Resharder.while_block_info[ op.attr("sub_block").id ]["var_reshard_mapping"].keys() ): Resharder.while_block_info[op.attr("sub_block").id][ "var_reshard_mapping" ][var_name] = [] Resharder.while_block_info[op.attr("sub_block").id][ "var_reshard_mapping" ][var_name].append([dist_attr, target_tensor.name]) # rename op input name according to new name for op in block.ops: # just for while op while_op_X_append = [] for name in op.input_arg_names: op_dist_attr = ( self.dist_context.get_op_dist_attr_for_program(op) ) if name == var_name and op_dist_attr is not None: if op.desc.id() == matched_op.desc.id(): if matched_op.type == "while": op.desc._rename_input( name, target_tensor.name ) old_name = name new_name = target_tensor.name assert old_name != new_name op_input_dist_attr = ( op_dist_attr.get_input_dist_attr( old_name ) ) op_dist_attr.set_input_dist_attr( new_name, op_input_dist_attr ) op_dist_attr.set_input_dims_mapping( new_name, dims_mapping ) # if ( # old_name # in op_dist_attr._inputs_dist_attrs # ): # op_dist_attr.del_input_dist_attr( # old_name # ) op_dist_attr.set_input_dims_mapping( new_name, dims_mapping ) while_op_X_append.append(new_name) continue else: op.desc._rename_input( name, target_tensor.name ) old_name = name new_name = target_tensor.name assert old_name != new_name op_input_dist_attr = ( op_dist_attr.get_input_dist_attr( old_name ) ) op_dist_attr.set_input_dist_attr( new_name, op_input_dist_attr ) op_dist_attr.set_input_dims_mapping( new_name, dims_mapping ) # op_dist_attr.del_input_dist_attr(old_name) op_dist_attr.set_input_dims_mapping( new_name, dims_mapping ) continue op_process_mesh = op_dist_attr.process_mesh op_input_dims_mapping = ( op_dist_attr.get_input_dims_mapping(var_name) ) # NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation. 
if ( op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping ): op.desc._rename_input(name, target_tensor.name) old_name = name new_name = target_tensor.name assert old_name != new_name op_input_dist_attr = ( op_dist_attr.get_input_dist_attr(old_name) ) op_dist_attr.set_input_dist_attr( new_name, op_input_dist_attr ) op_dist_attr.set_input_dims_mapping( new_name, dims_mapping ) # op_dist_attr.del_input_dist_attr(old_name) op_dist_attr.set_input_dims_mapping( new_name, dims_mapping ) # for while op, the input X should reset if while_op_X_append: proto = OpProtoHolder.instance().get_op_proto(op.type) op.desc.set_input( proto.inputs[0].name, op.input("X") + while_op_X_append, ) def _get_subblock_input_attrs(self, op, var_name): # NOTE: Multi while loop is not supported assert op.type in _g_subblock_ops sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] ops = sub_block.ops input_attrs = [] for op in ops: dist_op = self.dist_context.get_dist_op_for_program(op) if not dist_op: continue dist_attr = dist_op.dist_attr for name in op.input_arg_names: if name == var_name: process_mesh = dist_attr.process_mesh input_dims_mapping = dist_attr.get_input_dims_mapping( var_name ) has_exist = False for input_attr in input_attrs: if ( process_mesh == input_attr[0] and input_dims_mapping == input_attr[1] ): has_exist = True break if not has_exist: input_attrs.append([process_mesh, input_dims_mapping]) return input_attrs def _get_subblock_output_attrs(self, op, var_name): # NOTE: Multi while loop is not supported assert op.type in _g_subblock_ops sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] ops = sub_block.ops output_attrs = [] for op in ops: dist_op = self.dist_context.get_dist_op_for_program(op) if not dist_op: continue dist_attr = dist_op.dist_attr for name in op.output_arg_names: if name == var_name: process_mesh = dist_attr.process_mesh output_dims_mapping = dist_attr.get_output_dims_mapping( var_name ) has_exist = False for output_attr in output_attrs: if ( process_mesh == output_attrs[0] and output_dims_mapping == output_attrs[1] ): has_exist = True break if not has_exist: output_attrs.append([process_mesh, output_dims_mapping]) return output_attrs def _get_common_op_input_attrs(self, op, var_name): process_meshes = [] dist_op = self.dist_context.get_dist_op_for_program(op) dist_attr = dist_op.dist_attr op_process_mesh = dist_attr.process_mesh for process_mesh in self.dist_context.process_meshes: if set(process_mesh.process_ids) & ( set(op_process_mesh.process_ids) ) and len(process_mesh.process_ids) < len( op_process_mesh.process_ids ): process_meshes.append(process_mesh) # it means that the process mesh is not a union when process meshes is none if not process_meshes: process_meshes.append(op_process_mesh) input_dims_mapping = dist_attr.get_input_dims_mapping(var_name) input_attrs = [] for process_mesh in process_meshes: input_attrs.append([process_mesh, input_dims_mapping]) return input_attrs def get_op_input_attrs(self, op, var_name): op_input_attrs = [] if op.type in _g_subblock_ops: op_input_attrs = self._get_subblock_input_attrs(op, var_name) if not op_input_attrs: # NOTE: [hack method] # Adapt to quantization pass, which presist_vars, including inputs and outputs, all are in global_block. # Therefore, the while_op's inputs will contain the all persist_vars, which will be inputs or output of the quantization op in subblock. 
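                # Fall back to the sub-block's output dist attrs so that such
                # persistable vars still get a usable [process_mesh, dims_mapping]
                # pair even though they only appear as outputs inside the sub-block.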
op_input_attrs = self._get_subblock_output_attrs(op, var_name) else: op_input_attrs = self._get_common_op_input_attrs(op, var_name) assert op_input_attrs return op_input_attrs def _remove_global_process_mesh(self): """Remove global process mesh from dist_context.process_meshes""" processes = set() process_mesh_count = len(self.dist_context.process_meshes) if process_mesh_count > 1: global_process_mesh_idx = None for process_mesh in self.dist_context.process_meshes: for process in process_mesh.process_ids: processes.add(process) for idx, process_mesh in enumerate( self.dist_context.process_meshes ): if len(set(process_mesh.process_ids)) == len(processes): global_process_mesh_idx = idx break if global_process_mesh_idx is not None: is_removed = False global_mesh = self.dist_context.process_meshes[idx] for i, mesh in enumerate(self.dist_context.process_meshes): if i == idx: continue if set(mesh.process_ids) < set(global_mesh.process_ids): is_removed = True if is_removed: self.dist_context.process_meshes.pop(idx) def _change_subblock_op_input_and_output(self, block_idx, block): if "var_reshard_mapping" in Resharder.while_block_info[block_idx]: var_reshard_mapping = Resharder.while_block_info[block_idx][ "var_reshard_mapping" ] for op in block.ops: for var_name in op.input_arg_names: if var_name in var_reshard_mapping: # in while sub block, the union process mesh is not split before reshard sub block dist_op = self.dist_context.get_dist_op_for_program(op) dist_attr = dist_op.dist_attr target_name = None for item in var_reshard_mapping[var_name]: if ( dist_attr.process_mesh == item[0][0] and dist_attr.get_input_dims_mapping(var_name) == item[0][1] ): target_name = item[1] break if target_name is None: continue else: op.desc._rename_input(var_name, target_name) dist_op = self.dist_context.get_dist_op_for_program( op ) op_dist_attr = dist_op.dist_attr old_name = var_name new_name = target_name assert old_name != new_name op_input_dist_attr = ( op_dist_attr.get_input_dist_attr(old_name) ) op_dist_attr.set_input_dist_attr( new_name, op_input_dist_attr ) # op_dist_attr.del_input_dist_attr(old_name) # the outputs also need to be renamed when the output name is the same with input name in inplace op for var_name in op.output_arg_names: # if the tensor has been resharded multiply, it is not supported now. if var_name in var_reshard_mapping: if len(var_reshard_mapping[var_name]) > 1: raise ValueError( "The scene is not supported that the output is inplaced and the tensor has been resharded multiply when as input." ) target_name = var_reshard_mapping[var_name][0][1] op.desc._rename_output(var_name, target_name) dist_op = self.dist_context.get_dist_op_for_program(op) op_dist_attr = dist_op.dist_attr old_name = var_name new_name = target_name assert old_name != new_name op_output_dist_attr = op_dist_attr.get_output_dist_attr( old_name ) op_dist_attr.set_output_dist_attr( new_name, op_output_dist_attr ) # op_dist_attr.del_output_dist_attr(old_name) def _reshard_input(self, block): idx = 0 while idx < len(block.ops): pre_op_count = len(block.ops) op = block.ops[idx] if self.is_special_op(op): idx += 1 continue dist_op = self.dist_context.get_dist_op_for_program(op) if dist_op is not None: op_input_dist_attrs = ( [] ) # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)] if op.type in _g_subblock_ops: if not self.is_condition_replicative(op): raise ValueError( "Please check the condition due to the dims mapping is not replicative." 
    def _reshard_input(self, block):
        idx = 0
        while idx < len(block.ops):
            pre_op_count = len(block.ops)
            op = block.ops[idx]

            if self.is_special_op(op):
                idx += 1
                continue

            dist_op = self.dist_context.get_dist_op_for_program(op)
            if dist_op is not None:
                # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
                op_input_dist_attrs = []
                if op.type in _g_subblock_ops:
                    if not self.is_condition_replicative(op):
                        raise ValueError(
                            "Please check the condition variable: its dims mapping is not replicative."
                        )
                    if (
                        op.attr("sub_block").id
                        not in Resharder.while_block_info
                    ):
                        Resharder.while_block_info[
                            op.attr("sub_block").id
                        ] = {}
                    Resharder.while_block_info[op.attr("sub_block").id][
                        "op_id"
                    ] = op.desc.id()

                if op.type == "while":
                    # the condition var's process mesh is the same as the op's and its dims_mapping is replicative, so it does not need reshard
                    input_var_names = op.input("X")
                elif op.type == "conditional_block":
                    input_var_names = op.input("Input")
                else:
                    input_var_names = op.input_arg_names
                # sort to keep the while op's X input order deterministic
                input_var_names.sort()

                idx_offset = 0
                for var_name in input_var_names:
                    # skip lod_tensor_blocking_queue_? name
                    if "lod_tensor_blocking_queue" in var_name:
                        continue
                    var = get_var_with_recursion(
                        var_name, block, self.auto_parallel_main_prog
                    )
                    dist_tensor = self.dist_context.get_dist_tensor_for_program(
                        var
                    )

                    # check whether the union tensor's dims_mapping is all -1
                    is_union_process_mesh_tensor = False
                    if (
                        dist_tensor.dist_attr.process_mesh
                        not in self.dist_context.process_meshes
                        and self.dist_context.process_meshes
                    ):
                        is_union_process_mesh_tensor = True
                        assert dist_tensor.dist_attr.dims_mapping.count(
                            -1
                        ) == len(dist_tensor.dist_attr.dims_mapping)

                    op_input_attrs = self.get_op_input_attrs(op, var_name)
                    for input_attr in op_input_attrs:
                        input_process_mesh = None

                        # deal with the union tensor
                        if is_union_process_mesh_tensor:
                            # if the op's process mesh is a subset of the union tensor's process mesh, no reshard is needed
                            if set(input_attr[0].process_ids) <= set(
                                dist_tensor.dist_attr.process_mesh.process_ids
                            ):
                                continue

                        if dist_tensor is not None and self.need_reshard(
                            dist_tensor, input_attr
                        ):
                            reshard_op_desc = self.find_op_desc_seq(
                                dist_tensor, input_attr
                            )
                            self.parse_op_desc(
                                block, reshard_op_desc, var_name, op, input_attr
                            )
                            cur_op_count = len(block.ops)
                            idx_offset = (
                                idx_offset + cur_op_count - pre_op_count
                            )
                            pre_op_count = cur_op_count

                idx = idx + idx_offset + 1
            else:
                idx += 1
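    # Index bookkeeping sketch: parse_op_desc may insert new ops in front of the
    # op being visited, so the loop above tracks how many ops were added and
    # skips over them. For example, with pre_op_count == 10 and three reshard
    # ops inserted, cur_op_count == 13, idx_offset grows by 3, and the next op
    # visited is at idx + 3 + 1, i.e. the op immediately after the current one.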
    def _handle_recv(self, block, idx, var, op, send_rank, recv_rank):
        if self.rank_id == recv_rank:
            # if recv bool data, recv then cast
            if var.dtype == paddle.bool:
                recv_cast_out = block.create_var(
                    name=unique_name.generate(var.name + "@recv"),
                    shape=var.shape,
                    lod_level=var.lod_level,
                    dtype=paddle.int64,
                    type=var.type,
                )
                Inserter.insert_recv_op(
                    block,
                    idx + 1,
                    recv_cast_out,
                    send_rank,
                    recv_rank,
                    op.attr('op_role'),
                )
                reset_lod_out = None
                if var.lod_level != 0:
                    set_lod = False
                    for tmp_block in self.auto_parallel_main_prog.blocks:
                        for tmp_var_name in tmp_block.vars:
                            tmp_var = tmp_block.vars[tmp_var_name]
                            if (
                                tmp_var.is_data
                                and tmp_var.lod_level == var.lod_level
                            ):
                                reset_lod_out = block.create_var(
                                    name=unique_name.generate(
                                        var.name + "@RESETLOD"
                                    ),
                                    shape=recv_cast_out.shape,
                                    type=recv_cast_out.type,
                                    dtype=recv_cast_out.dtype,
                                    lod_level=recv_cast_out.lod_level,
                                )
                                idx += 1
                                block._insert_op(
                                    idx,
                                    type="lod_reset",
                                    inputs={'X': recv_cast_out, 'Y': tmp_var},
                                    outputs={'Out': reset_lod_out},
                                    attrs={'op_role': op.attr("op_role")},
                                )
                                set_lod = True
                                break
                        if set_lod:
                            break
                    assert set_lod is True

                # cast int64 to bool
                block._insert_op(
                    idx + 2,
                    type='cast',
                    inputs={
                        'X': [recv_cast_out]
                        if reset_lod_out is None
                        else [reset_lod_out]
                    },
                    outputs={'Out': [var]},
                    attrs={
                        'in_dtype': recv_cast_out.dtype,
                        'out_dtype': var.dtype,
                        'op_role': op.attr('op_role'),
                    },
                )
            else:
                if var.lod_level != 0:
                    recv_out = block.create_var(
                        name=unique_name.generate(var.name + "@recv"),
                        shape=var.shape,
                        lod_level=var.lod_level,
                        dtype=var.dtype,
                        type=var.type,
                    )
                    Inserter.insert_recv_op(
                        block,
                        idx + 1,
                        recv_out,
                        send_rank,
                        recv_rank,
                        op.attr('op_role'),
                    )
                    set_lod = False
                    for tmp_block in self.auto_parallel_main_prog.blocks:
                        for tmp_var_name in tmp_block.vars:
                            tmp_var = tmp_block.vars[tmp_var_name]
                            if (
                                tmp_var.is_data
                                and tmp_var.lod_level == var.lod_level
                            ):
                                idx += 1
                                block._insert_op(
                                    idx,
                                    type="lod_reset",
                                    inputs={'X': recv_out, 'Y': tmp_var},
                                    outputs={'Out': var},
                                    attrs={'op_role': op.attr("op_role")},
                                )
                                set_lod = True
                                break
                        if set_lod:
                            break
                    assert set_lod is True
                else:
                    Inserter.insert_recv_op(
                        block,
                        idx + 1,
                        var,
                        send_rank,
                        recv_rank,
                        op.attr('op_role'),
                    )

    def _handle_send(self, block, idx, var, op, send_rank, recv_rank):
        if var.dtype == paddle.bool:
            cast_out = Inserter.insert_cast_op(
                block, idx + 1, var, op.attr('op_role'), paddle.int64
            )
            Inserter.insert_send_op(
                block,
                idx + 2,
                cast_out,
                send_rank,
                recv_rank,
                op.attr('op_role'),
            )
        else:
            Inserter.insert_send_op(
                block, idx + 1, var, send_rank, recv_rank, op.attr('op_role')
            )

    def _reshard_output(self, block):
        # insert send and recv op if output process mesh is different from tensor process mesh
        idx = 0
        # skip reader ops and ops whose process mesh is a union
        skip_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
            "write_to_array",
            "read_from_array",
            "nop",
            "depend",
        ]
        global _g_special_ops
        skip_ops += _g_special_ops
        skip_ops += _g_subblock_ops
        while idx < len(block.ops):
            pre_op_count = len(block.ops)
            op = block.ops[idx]
            dist_op = self.dist_context.get_dist_op_for_program(op)
            if dist_op is not None and op.type not in skip_ops:
                idx_offset = 0
                for var_name in op.output_arg_names:
                    var = get_var_with_recursion(
                        var_name, block, self.auto_parallel_main_prog
                    )
                    dist_tensor = self.dist_context.get_dist_tensor_for_program(
                        var
                    )
                    tensor_process_mesh = dist_tensor.dist_attr.process_mesh
                    output_attr = [
                        dist_op.dist_attr.process_mesh,
                        dist_op.dist_attr.get_output_dims_mapping(var_name),
                    ]
                    if dist_tensor is not None and self.need_reshard(
                        dist_tensor, output_attr, False
                    ):
                        tensor_processes = set(
                            tensor_process_mesh.process_ids
                        ) - (
                            set(tensor_process_mesh.process_ids)
                            & set(output_attr[0].process_ids)
                        )
                        if tensor_processes:
                            if len(tensor_processes) != len(
                                output_attr[0].process_ids
                            ):
                                if dist_tensor.dist_attr.dims_mapping.count(
                                    -1
                                ) != len(
                                    dist_tensor.dist_attr.dims_mapping
                                ) or output_attr[1].count(-1) != len(
                                    output_attr[1]
                                ):
                                    raise ValueError(
                                        "The dims_mapping must be -1"
                                    )
                                else:
                                    for index, tensor_process in enumerate(
                                        tensor_processes
                                    ):
                                        recv_rank = tensor_process
                                        actual_index = index
                                        if index >= len(
                                            output_attr[0].process_ids
                                        ):
                                            actual_index = (
                                                index
                                                - len(
                                                    output_attr[0].process_ids
                                                )
                                            ) % len(output_attr[0].process_ids)
                                        item = output_attr[0].process_ids[
                                            actual_index
                                        ]
                                        if recv_rank == item:
                                            continue
                                        if self.rank_id == item:
                                            # if send bool data, cast then send
                                            self._handle_send(
                                                block,
                                                idx,
                                                var,
                                                op,
                                                item,
                                                recv_rank,
                                            )
                                        if self.rank_id == recv_rank:
                                            # if recv bool data, recv then cast
                                            self._handle_recv(
                                                block,
                                                idx,
                                                var,
                                                op,
                                                item,
                                                recv_rank,
                                            )
                            else:
                                for index, tensor_process in enumerate(
                                    tensor_processes
                                ):
                                    recv_rank = tensor_process
                                    item = output_attr[0].process_ids[index]
                                    if recv_rank == item:
                                        continue
                                    if self.rank_id == item:
                                        # if send bool data, cast then send
                                        self._handle_send(
                                            block, idx, var, op, item, recv_rank
                                        )
                                    if self.rank_id == recv_rank:
                                        # if recv bool data, recv then cast
                                        self._handle_recv(
                                            block, idx, var, op, item, recv_rank
                                        )

                            cur_op_count = len(block.ops)
                            idx_offset = (
                                idx_offset + cur_op_count - pre_op_count
                            )
                            pre_op_count = cur_op_count

                idx = idx + idx_offset + 1
            else:
                idx += 1
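    # Bool-tensor sketch: bool data is round-tripped through int64 by the helpers
    # above. On the sending rank, _handle_send inserts a cast (bool -> int64)
    # followed by a send op; on the receiving rank, _handle_recv inserts a recv
    # op into an int64 temporary, an optional lod_reset, and a final cast
    # (int64 -> bool) back into the original variable.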
    def reshard(self):
        self._remove_global_process_mesh()
        for block_idx, block in enumerate(
            self.auto_parallel_main_prog.blocks
        ):
            # change the var_name before resharding the sub block
            if block_idx in Resharder.while_block_info:
                self._change_subblock_op_input_and_output(block_idx, block)

            # reshard input
            self._reshard_input(block)

            # reshard output
            # NOTE: Only send and recv ops are inserted when the output process mesh differs from the tensor process mesh
            self._reshard_output(block)

        # remove unneeded vars and ops in the main program
        Remover.remove_no_need_in_main(
            self.auto_parallel_main_prog,
            self.dist_context,
            self.rank_id,
            self.dist_params_grads,
        )

        # remove unneeded vars and ops in the startup program
        Remover.remove_no_need_in_startup(
            self.auto_parallel_main_prog, self.auto_parallel_startup_prog
        )

        # reset the bookkeeping after the removal pass finishes
        Resharder.while_block_info = {}

    def get_cost(self, op, tensor, cluster):
        # NOTE: The program should be the serial_program which has not been partitioned
        global _g_special_ops
        not_supported_op_type = _g_special_ops + ["while"]
        reshard_op_cost = None
        if op.type in not_supported_op_type:
            return reshard_op_cost
        else:
            tensor_name = tensor.name
            if tensor_name == "lod_tensor_blocking_queue_0":
                return reshard_op_cost
            else:
                dist_tensor = self.dist_context.get_dist_tensor_for_program(
                    tensor
                )
                # simplified processing: ignore union process mesh and output reshard
                dist_op = self.dist_context.get_dist_op_for_program(op)
                dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                    tensor.name
                )
                process_mesh = dist_op.dist_attr.process_mesh
                dist_attr = [process_mesh, dims_mapping]
                if dist_tensor is not None and self.need_reshard(
                    dist_tensor, dist_attr
                ):
                    if tensor_name not in self._has_resharded:
                        self._has_resharded[tensor_name] = [dist_op]
                    else:
                        for item in self._has_resharded[tensor_name]:
                            item_dist_attr = item.dist_attr
                            item_dims_mapping = (
                                item_dist_attr.get_input_dims_mapping(
                                    tensor_name
                                )
                            )
                            item_process_mesh = item_dist_attr.process_mesh
                            if (
                                dims_mapping == item_dims_mapping
                                and item_process_mesh == process_mesh
                            ):
                                return reshard_op_cost
                        self._has_resharded[tensor_name].append(dist_op)

                    reshard_op_desc = self.find_op_desc_seq(
                        dist_tensor, dist_attr, serial=True
                    )
                    dtype = dist_tensor.serial_tensor.dtype
                    reshard_op_cost = self.parse_op_desc_for_cost(
                        reshard_op_desc, dtype, cluster
                    )

        return reshard_op_cost

    def _concat_partitions_for_cost(
        self,
        partition_tensor_list,
        partition_index,
        dtype,
        rank_id,
        local_rank_comp_cost,
        cluster,
    ):
        if not partition_tensor_list:
            partition_tensor_list.append(partition_index)
        else:
            i = 0
            has_concat = False
            while i < len(partition_tensor_list):
                (
                    concat_axis,
                    first_order,
                    new_partition,
                ) = Resharder.compute_concat_info(
                    partition_tensor_list[i], partition_index
                )
                if concat_axis != -1:
                    has_concat = True
                    concat_desc = {}
                    concat_desc["op"] = "concat"
                    concat_desc["attrs"] = {"axis": concat_axis}
                    if first_order == 0:
                        concat_desc["inputs"] = {
                            "X": [
                                (dtype, partition_tensor_list[i]),
                                (dtype, partition_index),
                            ]
                        }
                    else:
                        concat_desc["inputs"] = {
                            "X": [
                                (dtype, partition_index),
                                (dtype, partition_tensor_list[i]),
                            ]
                        }
                    partition_tensor_list.pop(i)
                    if rank_id not in local_rank_comp_cost:
                        local_rank_comp_cost[rank_id] = []
                    local_rank_comp_cost[rank_id].append(
                        ConcatOpCost(op_desc=concat_desc, cluster=cluster)
                    )
                    self._concat_partitions_for_cost(
                        partition_tensor_list,
                        new_partition,
                        dtype,
                        rank_id,
                        local_rank_comp_cost,
                        cluster,
                    )
                    break
                i += 1
            if not has_concat:
                partition_tensor_list.append(partition_index)
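    # Illustrative sketch (placeholder values): a partition index is a list of
    # [start, end) pairs, one per dimension. When two partitions touch along a
    # single axis, compute_concat_info is assumed to report that axis and the
    # merged partition, e.g. [[0, 4], [0, 8]] and [[4, 8], [0, 8]] merge along
    # axis 0 into [[0, 8], [0, 8]]; a ConcatOpCost is then recorded for the rank
    # before recursing with the merged piece.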
    def parse_op_desc_for_cost(self, reshard_op_desc, dtype, cluster):
        def _get_idx(comm_ranks, group_ranks):
            res, is_the_same = None, False
            idx = 0
            while idx < len(comm_ranks):
                if comm_ranks[idx] == set(group_ranks):
                    is_the_same = True

                for rank in group_ranks:
                    if rank in comm_ranks[idx]:
                        res = idx
                        comm_ranks[idx].add(rank)
                if res is None:
                    idx += 1
                else:
                    break
            return res, is_the_same

        comm_context = CommContext(cluster)
        # run communication ops before computation ops
        # TODO: Communication cost is not calculated when the var has been transferred by the same group in the past
        comm_costs = []
        comm_ranks = []
        local_rank_comp_cost = {}
        for key in reshard_op_desc:
            partition_tensor_list = []
            op_desc_list = reshard_op_desc[key]
            for op_desc in op_desc_list:
                if isinstance(op_desc, SendOpDesc):
                    group_ranks = [key, op_desc.dst]
                    shape = op_desc.shape
                    send_desc = build_comm_desc(
                        "send_v2", group_ranks, dtype, shape
                    )
                    idx, is_the_same = _get_idx(comm_ranks, group_ranks)
                    if idx is None:
                        comm_costs.append(
                            [
                                (
                                    group_ranks,
                                    SendOpCost(
                                        op_desc=send_desc,
                                        comm_context=comm_context,
                                    ),
                                )
                            ]
                        )
                        comm_ranks.append(set(group_ranks))
                    else:
                        if not is_the_same:
                            comm_costs[idx].append(
                                (
                                    group_ranks,
                                    SendOpCost(
                                        op_desc=send_desc,
                                        comm_context=comm_context,
                                    ),
                                )
                            )
                elif isinstance(op_desc, AllGatherOpDesc):
                    # NOTE: fill_constant and other auxiliary ops are not counted because their cost is very small
                    group_ranks = op_desc.group
                    shape = op_desc.shape
                    allgather_desc = build_comm_desc(
                        "c_allgather", group_ranks, dtype, shape
                    )
                    split_inputs_shape = []
                    for idx, dim in enumerate(shape):
                        if idx == 0:
                            split_inputs_shape.append(dim * len(group_ranks))
                        else:
                            split_inputs_shape.append(dim)
                    idx, is_the_same = _get_idx(comm_ranks, group_ranks)
                    if idx is None:
                        comm_costs.append(
                            [
                                (
                                    group_ranks,
                                    AllgatherOpCost(
                                        op_desc=allgather_desc,
                                        comm_context=comm_context,
                                    ),
                                )
                            ]
                        )
                        comm_ranks.append(set(group_ranks))
                    else:
                        if not is_the_same:
                            comm_costs[idx].append(
                                (
                                    group_ranks,
                                    AllgatherOpCost(
                                        op_desc=allgather_desc,
                                        comm_context=comm_context,
                                    ),
                                )
                            )
                    # calc the split op cost
                    if key not in local_rank_comp_cost:
                        local_rank_comp_cost[key] = []
                    split_desc = {}
                    split_desc["op"] = "split"
                    split_desc["inputs"] = {
                        "inputs": [(dtype, split_inputs_shape)]
                    }
                    split_desc["attrs"] = {"num": len(group_ranks), "axis": 0}
                    local_rank_comp_cost[key].append(
                        SplitOpCost(op_desc=split_desc, cluster=cluster)
                    )
                elif isinstance(op_desc, ConcatOpDesc):
                    partition_index_list = op_desc.partition_index_list
                    for idx, partition_index in enumerate(
                        partition_index_list
                    ):
                        self._concat_partitions_for_cost(
                            partition_tensor_list,
                            partition_index,
                            dtype,
                            key,
                            local_rank_comp_cost,
                            cluster,
                        )
                elif isinstance(op_desc, SliceOpDesc):
                    if key not in local_rank_comp_cost:
                        local_rank_comp_cost[key] = []
                    assert (
                        len(partition_tensor_list) == 1
                        or not partition_tensor_list
                    )
                    to_slice_tensor_shape = []
                    if len(partition_tensor_list) == 1:
                        for item in partition_tensor_list[0]:
                            to_slice_tensor_shape.append(item[1] - item[0])
                    else:
                        to_slice_tensor_shape = op_desc.shape
                    slice_desc = {}
                    slice_desc["op"] = "slice"
                    infer_flags = [1 for _ in range(len(op_desc.axes))]
                    slice_desc["attrs"] = {
                        "axes": op_desc.axes,
                        "starts": op_desc.starts,
                        "ends": op_desc.ends,
                        "infer_flags": infer_flags,
                    }
                    slice_desc["inputs"] = {
                        "Input": [(dtype, to_slice_tensor_shape)]
                    }
                    local_rank_comp_cost[key].append(
                        SliceOpCost(op_desc=slice_desc, cluster=cluster)
                    )

        res = (comm_costs, local_rank_comp_cost)

        return res
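    # Return-value sketch: parse_op_desc_for_cost yields a tuple
    # (comm_costs, local_rank_comp_cost), where comm_costs is a list of groups,
    # each group being a list of (group_ranks, communication cost) pairs such as
    # SendOpCost or AllgatherOpCost, and local_rank_comp_cost maps a rank id to
    # the computation costs (SplitOpCost, ConcatOpCost, SliceOpCost) that rank
    # would execute.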