From 8624f3b104ea497ae8b2bebfdcea5dd1887e17e7 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Fri, 12 Aug 2022 19:05:27 +0800 Subject: [PATCH] [Auto Parallel] Update reshard for auto search (#45002) * update reshard for auto search * fix unittest bug * update dist tensor * update reshard output * fix unittests bug * merge develop --- .../auto_parallel/dist_attribute.py | 8 +- .../distributed/auto_parallel/dist_tensor.py | 25 +- .../distributed/auto_parallel/reshard.py | 1258 ++++++++++++----- .../paddle/distributed/auto_parallel/utils.py | 5 +- .../passes/auto_parallel_sharding.py | 5 +- .../unittests/auto_parallel_data_unshard.py | 4 +- .../unittests/test_auto_parallel_reshard.py | 5 +- 7 files changed, 941 insertions(+), 369 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/dist_attribute.py b/python/paddle/distributed/auto_parallel/dist_attribute.py index 9bbc4de6bdd..ff07deb42aa 100644 --- a/python/paddle/distributed/auto_parallel/dist_attribute.py +++ b/python/paddle/distributed/auto_parallel/dist_attribute.py @@ -276,8 +276,8 @@ class OperatorDistributedAttribute: dist_attr_object.init(dist_attr) self._inputs_dist_attrs[name] = dist_attr_object - # def del_input_dist_attr(self, name): - # del self._inputs_dist_attrs[name] + def del_input_dist_attr(self, name): + del self._inputs_dist_attrs[name] def get_output_dist_attr(self, name): return self._outputs_dist_attrs.get(name, None) @@ -287,8 +287,8 @@ class OperatorDistributedAttribute: dist_attr_object.init(dist_attr) self._outputs_dist_attrs[name] = dist_attr_object - # def del_output_dist_attr(self, name): - # del self._inputs_dist_attrs[name] + def del_output_dist_attr(self, name): + del self._outputs_dist_attrs[name] def get_input_dims_mapping(self, name): input_dist_attr = self.get_input_dist_attr(name) diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py index b6228f5ad0e..59a2d7a5823 100644 --- a/python/paddle/distributed/auto_parallel/dist_tensor.py +++ b/python/paddle/distributed/auto_parallel/dist_tensor.py @@ -163,7 +163,6 @@ class DistributedTensor: self._batch_dim = 0 # Reuse the dist_attr setter to initialize _dist_attr self.dist_attr = dist_attr - self._local_sizes_map = {} self._local_offsets_map = {} self._local_shard_map = {} self._local_tensor_map = {} @@ -223,20 +222,17 @@ class DistributedTensor: return True def local_sizes(self, rank=None): + """Get local sizes of the given rank.""" rank = paddle.distributed.get_rank() if rank is None else rank - local_sizes = None - if rank in self._local_sizes_map.keys(): - local_sizes = self._local_sizes_map[rank] - else: - global_sizes = self.serial_tensor.shape - dims_mapping = self.dist_attr.dims_mapping - shard_sizes = self.dist_attr.shard_sizes - processes = self.dist_attr.process_mesh.processes - topology = self.dist_attr.process_mesh.topology - local_sizes = DistributedTensor.get_local_sizes( - global_sizes, dims_mapping, topology, processes, rank, - shard_sizes) - self._local_sizes_map[rank] = local_sizes + global_sizes = self.serial_tensor.shape + dims_mapping = self.dist_attr.dims_mapping + shard_sizes = self.dist_attr.shard_sizes + processes = self.dist_attr.process_mesh.processes + topology = self.dist_attr.process_mesh.topology + local_sizes = DistributedTensor.get_local_sizes(global_sizes, + dims_mapping, topology, + processes, rank, + shard_sizes) return local_sizes @@ -282,7 +278,6 @@ class DistributedTensor: def new_local_tensor(self, block=None, rank=None, name=None): """ Create a new local tensor of serial tensor corresponding to rank. - Args: block (Block): The block contains the new tensor. Default value is recommend and it will be created in the block of dist main program corresponding to the serial tensor block id. Default: None. rank (int): The rank id. Default value is recommend and it will be the current rank. Default: None. diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 8fb38142218..6b902d6fb77 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -26,6 +26,11 @@ from ..collective import _get_global_env from .dist_context import DistributedContext from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from .process_group import new_process_group, ProcessGroup, _g_process_group_map +from .cost import build_comm_desc, CommContext +from .cost import AllgatherOpCost, SendOpCost +from .cost import SliceOpCost, SplitOpCost, ConcatOpCost +from .cluster import Cluster +from .utils import print_program_with_dist_attr # NOTE: If op in _g_special_ops, it will not be resharded. _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] @@ -41,6 +46,7 @@ def get_var_with_recursion(var_name, block, program): if var_name in parent_block.vars: var = parent_block.vars[var_name] assert var is not None + return var @@ -50,11 +56,19 @@ class AllGatherOpDesc: Args: group (list): Process group. + shape (list): The tensor shape. + is_bool (bool): Whether allgather bool data. Default: False. """ - def __init__(self, group): + def __init__(self, group, shape, is_bool=False): self._group = group self._desc = "all_gather" + self._shape = shape + self._is_bool = is_bool + + @property + def is_bool(self): + return self._is_bool @property def group(self): @@ -64,8 +78,12 @@ class AllGatherOpDesc: def desc(self): return self._desc + @property + def shape(self): + return self._shape + def __repr__(self): - return f"op: {self._desc}, group: {self._group}." + return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}." class SendOpDesc: @@ -74,13 +92,26 @@ class SendOpDesc: Args: partition_index (list): The index of partition in complete tensor. + src (int): The source process to send. dst (int): The destination process to receive. + is_bool (bool): Whether send bool data. Default: False. """ - def __init__(self, partition_index, dst): + def __init__(self, partition_index, src, dst, is_bool=False): self._dst = dst self._partition_index = partition_index self._desc = "send" + self._shape = [] + self._is_bool = is_bool + self._src = src + + @property + def src(self): + return self._src + + @property + def is_bool(self): + return self._is_bool @property def partition_index(self): @@ -94,8 +125,15 @@ class SendOpDesc: def desc(self): return self._desc + @property + def shape(self): + if not self._shape: + for item in self.partition_index: + self._shape.append(item[1] - item[0]) + return self._shape + def __repr__(self): - return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}." + return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}." class RecvOpDesc: @@ -105,12 +143,25 @@ class RecvOpDesc: Args: partition_index (list): The index of partition in complete tensor. src (int): The source process to send. + dst (int): The destination process to receive. + is_bool (bool): Whether receive bool data. Default: False. """ - def __init__(self, partition_index, src): + def __init__(self, partition_index, src, dst, is_bool=False): self._src = src self._partition_index = partition_index self._desc = "recv" + self._shape = [] + self._is_bool = is_bool + self._dst = dst + + @property + def dst(self): + return self._dst + + @property + def is_bool(self): + return self._is_bool @property def partition_index(self): @@ -124,8 +175,15 @@ class RecvOpDesc: def desc(self): return self._desc + @property + def shape(self): + if not self._shape: + for item in self.partition_index: + self._shape.append(item[1] - item[0]) + return self._shape + def __repr__(self): - return f"op: {self._desc}, partition_index: {self._partition_index}, src: {self._src}." + return f"op: {self._desc}, partition_index: {self._partition_index}, dst: {self._dst}, shape: {self._shape}, is_bool: {self._is_bool}." class SliceOpDesc: @@ -133,16 +191,18 @@ class SliceOpDesc: Describe the slice op in the reshard phase. Args: - starts (list): It represents starting indices of corresponding axis in ``axes``. - ends (list): It represents ending indices of corresponding axis in ``axes``. - axes (list): Axes that `starts` and `ends` apply to . + starts (list): It represents start indices of corresponding axis in ``axes``. + ends (list): It represents end indices of corresponding axis in ``axes``. + axes (list): Axes that `starts` and `ends` apply to. + shape (list): The shape of the tensor to be sliced. """ - def __init__(self, starts, ends, axes): + def __init__(self, starts, ends, axes, shape=None): self._starts = starts self._ends = ends self._axes = axes self._desc = "slice" + self._shape = shape @property def starts(self): @@ -160,8 +220,15 @@ class SliceOpDesc: def desc(self): return self._desc + @property + def shape(self): + return self._shape + def __repr__(self): - return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}." + if self._shape is not None: + return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}, shape: {self._shape}." + else: + return f"op: {self._desc}, starts: {self._starts}, ends: {self._ends}, axes: {self._axes}." class ConcatOpDesc: @@ -192,36 +259,84 @@ class Inserter: """Insert op required in the reshard process.""" @staticmethod - def insert_send_op(block, idx, tensor, dst, op_role): + def insert_cast_op(block, idx, tensor, op_role, tensor_type): + # to avoid name conflict with framework + new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key( + ".".join(["cast@RESHARD", 'tmp'])) + out = block.create_var(name=new_var_name, + dtype=tensor_type, + type=tensor.type, + lod_level=tensor.lod_level) + block._insert_op(idx, + type='cast', + inputs={'X': [tensor]}, + outputs={'Out': [out]}, + attrs={ + 'in_dtype': tensor.dtype, + 'out_dtype': out.dtype, + 'op_role': op_role + }) + return out + + @staticmethod + def insert_send_op(block, idx, tensor, src, dst, op_role): """Insert send op into block at the given index.""" op_type = 'send_v2' + # use pair comm group + process_group = new_process_group([src, dst]) block._insert_op(idx, type=op_type, inputs={'X': [tensor]}, attrs={ - 'ring_id': 0, - 'peer': dst, + 'ring_id': process_group.id, + 'peer': process_group.ranks.index(dst), 'use_calc_stream': True, - 'op_role': op_role + 'op_role': op_role, + 'dynamic_shape': True }) @staticmethod - def insert_recv_op(block, idx, tensor, src, op_role): + def insert_recv_op(block, idx, tensor, src, dst, op_role): """Insert recv op into block at the given index.""" op_type = 'recv_v2' + # use pair group + process_group = new_process_group([src, dst]) block._insert_op(idx, type=op_type, inputs={'X': [tensor]}, outputs={'Out': [tensor]}, attrs={ - 'ring_id': 0, - 'peer': src, + 'ring_id': process_group.id, + 'peer': process_group.ranks.index(src), 'out_shape': tensor.shape, 'dtype': tensor.dtype, 'use_calc_stream': True, - 'op_role': op_role + 'op_role': op_role, + 'dynamic_shape': True }) + @staticmethod + def insert_reset_lod_op(block, idx, X, Y, op_role): + """Insert reset_lod op into block at the given index.""" + + new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key( + ".".join(["reset_lod@RESHARD", 'tmp'])) + reset_lod_out = block.create_var(name=new_var_name, + shape=X.shape, + type=X.type, + dtype=X.dtype, + lod_level=X.lod_level) + + block._insert_op(idx, + type="lod_reset", + inputs={ + 'X': X, + 'Y': Y + }, + outputs={'Out': reset_lod_out}, + attrs={'op_role': op_role}) + return reset_lod_out + @staticmethod def insert_concat_op(block, idx, tensors, axis, op_role): """Insert concat op into block at the given block.""" @@ -229,10 +344,18 @@ class Inserter: attrs = {} attrs['axis'] = axis attrs['op_role'] = op_role - helper = LayerHelper('concat', **locals()) + # to avoid name conflict with framework + helper = LayerHelper('concat@RESHARD', **locals()) with paddle.static.program_guard(block.program): - out = helper.create_variable_for_type_inference( - dtype=helper.input_dtype()) + out = block.create_var( + name=paddle.fluid.unique_name.generate_with_ignorable_key( + ".".join([helper.name, 'tmp'])), + dtype=tensors[0].dtype, + shape=None, + lod_level=tensors[0].lod_level, + type=tensors[0].type, + persistable=False, + stop_gradient=False) block._insert_op(idx, type='concat', inputs=inputs, @@ -244,37 +367,117 @@ class Inserter: def insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name, op_role): """Insert slice op into block at the given block.""" - inputs = {'Input': tensor} - infer_flags = list(1 for i in range(len(axes))) - attrs = { - "axes": axes, - "starts": starts, - "ends": ends, - "infer_flags": infer_flags, - 'op_role': op_role - } - helper = LayerHelper('slice', **locals()) - out = block.create_var(name=new_var_name, - dtype=tensor.dtype, - type=tensor.type) - block._insert_op(idx, - type="slice", - inputs=inputs, - outputs={'Out': [out]}, - attrs=attrs) - return out + # This is a hack to insert split op to get slice tensor + # 1. [128, 128] => [64, 128]: split + # 2. [128, 128] => [128, 128]: assign + # 3. [128, 128] => [64, 64]: slice, it will replaced by multi split + global_shape = tensor.shape + slice_shape = [ends[i] - starts[i] for i in range(len(starts))] + diff_dims = [] + for index, item in enumerate(slice_shape): + if item != global_shape[index]: + diff_dims.append(index) + + # use assign + if len(diff_dims) == 0: + out = block.create_var(name=new_var_name, + dtype=tensor.dtype, + type=tensor.type, + shape=slice_shape, + lod_level=tensor.lod_level) + inputs = {'X': [tensor]} + outputs = {"Out": [out]} + attrs = {"in_place": False} + block._insert_op(idx, + type="assign", + inputs=inputs, + outputs=outputs, + attrs=attrs) + return out + + # use split once + elif len(diff_dims) == 1: + diff_dim = diff_dims[0] + num_or_sections = global_shape[diff_dim] // slice_shape[diff_dim] + axis = diff_dim + cur_idx = starts[diff_dim] // slice_shape[diff_dim] + input_shape = global_shape + inputs = {'X': tensor} + attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role} + new_shape = [] + for index, item in enumerate(tensor.shape): + if index != axis: + new_shape.append(item) + else: + new_shape.append(item // num_or_sections) + with paddle.static.program_guard(block.program): + outs = [ + block.create_var(name=paddle.fluid.unique_name. + generate_with_ignorable_key(".".join( + ['split@RESHARD', 'tmp'])), + dtype=tensor.dtype, + shape=None, + type=tensor.type, + persistable=False, + lod_level=tensor.lod_level, + stop_gradient=False) + for i in range(num_or_sections) + ] + out = outs[cur_idx] + op = block._insert_op(idx, + type="split", + inputs=inputs, + outputs={'Out': outs}, + attrs=attrs) + return out + + # use slice + else: + inputs = {'Input': tensor} + infer_flags = list(1 for i in range(len(axes))) + attrs = { + "axes": axes, + "starts": starts, + "ends": ends, + "infer_flags": infer_flags, + 'op_role': op_role + } + out = block.create_var(name=new_var_name, + dtype=tensor.dtype, + type=tensor.type, + lod_level=tensor.lod_level) + block._insert_op(idx, + type="slice", + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs) + + return out @staticmethod - def insert_split_op(block, idx, tensor, num_or_sections, op_role): + def insert_split_op(block, idx, tensor, num_or_sections, op_role, axis=0): """Insert split op into block at the given index.""" - helper = LayerHelper('split', **locals()) + helper = LayerHelper('split@RESHARD', **locals()) input_shape = tensor.shape inputs = {'X': tensor} - attrs = {'num': num_or_sections, 'axis': 0, 'op_role': op_role} + attrs = {'num': num_or_sections, 'axis': axis, 'op_role': op_role} + new_shape = [] + for index, item in enumerate(tensor.shape): + if index != axis: + new_shape.append(item) + else: + new_shape.append(item // num_or_sections) with paddle.static.program_guard(block.program): outs = [ - helper.create_variable_for_type_inference( - dtype=helper.input_dtype()) for i in range(num_or_sections) + block.create_var( + name=paddle.fluid.unique_name.generate_with_ignorable_key( + ".".join([helper.name, 'tmp'])), + dtype=tensor.dtype, + shape=None, + lod_level=tensor.lod_level, + type=tensor.type, + persistable=False, + stop_gradient=False) for i in range(num_or_sections) ] block._insert_op(idx, type="split", @@ -286,9 +489,18 @@ class Inserter: @staticmethod def insert_fill_constant_op(block, idx, op_role): """Insert fill constant op into block at the given index.""" - helper = LayerHelper("fill_constant", **locals()) + # to avoid name conflict with framework + helper = LayerHelper('fill_constant@RESHARD', **locals()) + # use paddle.int64 as dtype with paddle.static.program_guard(block.program): - out = helper.create_variable_for_type_inference(dtype="int32") + out = block.create_var( + name=paddle.fluid.unique_name.generate_with_ignorable_key( + ".".join([helper.name, 'tmp'])), + dtype=paddle.int64, + shape=None, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) inputs = {} attrs = {'force_cpu': False} attrs['str_value'] = str(int("1")) @@ -342,10 +554,18 @@ class Inserter: # insert c_allgather op op_type = 'c_allgather' - helper = LayerHelper(op_type, **locals()) + # to avoid name conflict with framework + helper = LayerHelper(op_type + "@RESHARD", **locals()) with paddle.static.program_guard(block.program): - allgather_out = helper.create_variable_for_type_inference( - dtype=tensor.dtype) + allgather_out = block.create_var( + name=paddle.fluid.unique_name.generate_with_ignorable_key( + ".".join([helper.name, 'tmp'])), + dtype=tensor.dtype, + shape=None, + lod_level=tensor.lod_level, + type=tensor.type, + persistable=False, + stop_gradient=False) block._insert_op(idx + idx_offset, type=op_type, inputs={'X': [tensor]}, @@ -620,12 +840,14 @@ class Resharder: batch_size=None): assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \ "but got {}.".format(type(auto_parallel_main_prog)) - assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program, " \ - "but got {}.".format(type(auto_parallel_startup_prog)) + if auto_parallel_startup_prog is not None: + assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program or None, " \ + "but got {}.".format(type(auto_parallel_startup_prog)) assert isinstance(rank_id, int), "The type of rank_id should be int, " \ "but got {}.".format(type(rank_id)) assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \ "but got {}.".format(type(dist_context)) + if batch_size is not None: assert isinstance(batch_size, int), "The type of batch_size should be int, " \ "but got {}.".format(type(batch_size)) @@ -639,6 +861,8 @@ class Resharder: self._has_sent = {} self._has_recv = {} self._has_allgather = {} + # to avoid reshard repeatly + self._has_resharded = {} @property def auto_parallel_main_prog(self): @@ -798,7 +1022,10 @@ class Resharder: for op in sub_block.ops: # skip the input and output of operators inserted in the reshard phase dist_op = dist_context.get_dist_op_for_program(op) - if dist_op: + if dist_op or (op.type == "slice" and not dist_op) or ( + op.type == "split" + and not dist_op) or (op.type == "assign" + and not dist_op): for var_name in op.output_arg_names: if var_name not in sub_block_op_outputs: sub_block_op_outputs.append(var_name) @@ -812,7 +1039,8 @@ class Resharder: while_op = op break - assert while_op is not None + if while_op is None: + continue # find the actual input and output of while op proto = OpProtoHolder.instance().get_op_proto(while_op.type) @@ -821,13 +1049,15 @@ class Resharder: if var_name in sub_block_op_inputs: new_X.append(var_name) assert new_X + new_X.sort() while_op.desc.set_input(proto.inputs[0].name, new_X) new_Out = [] for var_name in while_op.output("Out"): for output_name in sub_block_op_outputs[::-1]: if output_name.find(var_name) != -1: - new_Out.append(output_name) + if output_name not in new_Out: + new_Out.append(output_name) assert new_Out while_op.desc.set_output(proto.outputs[0].name, new_Out) @@ -870,120 +1100,72 @@ class Resharder: return True - def need_reshard(self, - dist_tensor, - dist_op, - actual_process_mesh, - op_input=True): + def need_reshard(self, dist_tensor, dist_attr, op_input=True, dist_op=None): """Judge the tensor whether needs to be resharded.""" is_reshard = False tensor_dist_attr = dist_tensor.dist_attr - tensor_name = dist_tensor.serial_tensor.name tensor_dims_mapping = tensor_dist_attr.dims_mapping tensor_process_mesh = tensor_dist_attr.process_mesh - op_dist_attr = dist_op.dist_attr - op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) - op_process_mesh = actual_process_mesh + + # dist_attr is [process_mesh, dims_mapping] and process_mesh is not a union + op_process_mesh = dist_attr[0] + if op_input: - op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_name) + op_input_dims_mapping = dist_attr[1] if all( - map(lambda x: x is not None, [ + map(lambda x: x, [ tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping, op_process_mesh ])): - # dims_mapping + # judge whether need reshard by dims_mapping if tensor_dims_mapping != op_input_dims_mapping: - if dist_op.serial_op.type == "while": - sub_block = self.auto_parallel_main_prog.blocks[ - dist_op.serial_op.attr("sub_block").id] - for op in sub_block.ops: - for var_name in op.input_arg_names: - if var_name == tensor_name: - dist_op_attr = self.dist_context.get_dist_op_for_program( - op).dist_attr - var_dims_mapping = dist_op_attr.get_input_dims_mapping( - var_name) - if var_dims_mapping != tensor_dims_mapping: - is_reshard = True - break - else: - is_reshard = True - # process_mesh - if tensor_process_mesh != op_process_mesh: - # when processes length is not the same, the dims mapping must be replicative now - if len(tensor_process_mesh.processes) != len( - op_process_mesh.processes): - assert self.is_unshard(tensor_dims_mapping) - assert self.is_unshard(op_input_dims_mapping) - else: - if dist_tensor.serial_tensor.dtype == paddle.bool: - raise ValueError( - "Bool var is not supported reshard.") - - # for while op, it should find the process mesh of op actually used the tensor as input - if dist_op.serial_op.type == "while": - sub_block = self.auto_parallel_main_prog.blocks[ - dist_op.serial_op.attr("sub_block").id] - for op in sub_block.ops: - for var_name in op.input_arg_names: - if var_name == tensor_name: - dist_op_attr = self.dist_context.get_dist_op_for_program( - op).dist_attr - process_mesh = dist_op_attr.process_mesh - if process_mesh == op_process_mesh: - is_reshard = True - break + if tensor_process_mesh not in self.dist_context.process_meshes: + # assert whether -1 when union. + for item in tensor_dims_mapping: + if item != -1: + raise ValueError( + "The dim must be -1 when tensor process mesh is a union." + ) + # tensor process_mesh: [0, 1, 2, 3], dims_mapping: [-1, -1] + # op process_mesh: [4, 5], dims_mapping: [0, -1] + # reshard is not supported such as above + if not is_reshard: + return is_reshard else: - is_reshard = True + raise ValueError( + "it is not supported that tensor process mesh is a union and needs reshard." + ) + is_reshard = True + + # judge whether need reshard by process_mesh + if tensor_process_mesh != op_process_mesh: + is_reshard = True else: - op_output_dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_name) + op_output_dims_mapping = dist_attr[1] if all( - map(lambda x: x is not None, [ + map(lambda x: x, [ tensor_dims_mapping, tensor_process_mesh, op_output_dims_mapping, op_process_mesh ])): - if tensor_process_mesh != op_process_mesh: - if dist_tensor.serial_tensor.dtype == paddle.bool: - raise ValueError("Bool var is not supported reshard.") - is_reshard = True if tensor_dims_mapping != op_output_dims_mapping: raise ValueError( "It is not supported that tensor dims mapping is different from op output dims mapping." ) + if tensor_process_mesh != op_process_mesh: + is_reshard = True return is_reshard - def get_process_meshes(self, op): - """Get all process meshes when op has sub block.""" - assert op.has_attr("sub_block") - sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] - ops = sub_block.ops - op_process_mesh = self.dist_context.get_dist_op_for_program( - op).dist_attr.process_mesh - process_meshes = [] - for op in ops: - dist_op = self.dist_context.get_dist_op_for_program(op) - if not dist_op: - continue - process_mesh = dist_op.dist_attr.process_mesh - if process_mesh not in process_meshes and process_mesh != op_process_mesh: - process_meshes.append(process_mesh) - - if not process_meshes: - process_meshes.append(op_process_mesh) - - return process_meshes - def get_op_process_meshes(self, op): + """Get sub process meshes of the given op if op process mesh is a union.""" process_meshes = [] dist_op = self.dist_context.get_dist_op_for_program(op) op_process_mesh = dist_op.dist_attr.process_mesh + for process_mesh in self.dist_context.process_meshes: if set(process_mesh.processes) & (set( op_process_mesh.processes)) and len( - process_mesh.processes) <= len( + process_mesh.processes) < len( op_process_mesh.processes): process_meshes.append(process_mesh) @@ -993,39 +1175,14 @@ class Resharder: return process_meshes - def get_while_op_actual_process_mesh(self, op): - """Get the while op actual Process mesh corresponding to rank""" - assert op.type == "while" - while_op_process_mesh = self.dist_context.get_dist_op_for_program( - op).dist_attr.process_mesh - sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] - ops = sub_block.ops - actual_process_mesh = None - for op in ops: - dist_op = self.dist_context.get_dist_op_for_program(op) - if not dist_op: - continue - process_mesh = dist_op.dist_attr.process_mesh - if process_mesh == while_op_process_mesh: - continue - if self.rank_id in process_mesh.processes: - raw_process_mesh = process_mesh - break - - if actual_process_mesh is None and self.rank_id in while_op_process_mesh.processes: - actual_process_mesh = while_op_process_mesh - - assert actual_process_mesh is not None - return actual_process_mesh - - def find_op_desc_seq(self, dist_tensor, dist_op, actual_process_mesh): + def find_op_desc_seq(self, dist_tensor, dist_attr, serial=False): """ Find the op description sequence to reshard the source tensor for matching the op requirement. Args: dist_tensor (DistributedTensor): A distributed tensor. - dist_op (DistributedOperator): A distributed operator. - actual_process_mesh (ProcessMesh): The actual op process mesh. + dist_attr (list): A list contains process_mesh and dims_mapping such as [process_mesh, dims_mapping]. + serial (bool): If serial is true, the dist tensor and dist op come from serial program. Otherwise, they come from auto program. Returns: Dict, the dict represents the required op description sequence corresponding to process, The key of dict is @@ -1034,24 +1191,26 @@ class Resharder: tensor_dist_attr = dist_tensor.dist_attr source_tensor = dist_tensor.serial_tensor tensor_name = source_tensor.name + source_dims_mapping = tensor_dist_attr.dims_mapping source_process_mesh = tensor_dist_attr.process_mesh source_process_group = source_process_mesh.processes source_process_shape = source_process_mesh.topology - op_dist_attr = dist_op.dist_attr - target_process_mesh = actual_process_mesh - target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name) + target_process_mesh = dist_attr[0] + target_dims_mapping = dist_attr[1] target_process_group = target_process_mesh.processes target_process_shape = target_process_mesh.topology if source_tensor.shape[0] < 0: + assert source_tensor.shape[0] == -1 new_shape = list(source_tensor.shape) new_shape[0] = self.batch_size source_tensor.desc.set_shape(new_shape) complete_shape = Resharder.compute_complete_shape( - source_tensor.shape, source_process_shape, source_dims_mapping) + source_tensor.shape, source_process_shape, + source_dims_mapping) if not serial else source_tensor.shape op_desc_seq = {} # TODO: if the target process group has the same process with source process group @@ -1060,13 +1219,14 @@ class Resharder: set(source_process_group)): pass - # in the different process group, it will use send, recv, concat and slice op elif target_process_group != source_process_group: partition_process_mapping_list = [] for source_process in source_process_group: + # get partition index of source process source_partition_index = Resharder.compute_partition_index(source_process, complete_shape, source_dims_mapping, \ source_process_shape, source_process_group) if not partition_process_mapping_list: + # the item in partition_process_mapping_list is source_partition_index, which processes and whether has been used partition_process_mapping_list.append( [source_partition_index, [source_process], [False]]) else: @@ -1076,6 +1236,7 @@ class Resharder: [item[1] for item in partition_process_mapping_list]) has_used = list( [item[2] for item in partition_process_mapping_list]) + if partition_list.count(source_partition_index) == 1: index = partition_list.index(source_partition_index) process_list[index].append(source_process) @@ -1085,6 +1246,7 @@ class Resharder: [source_partition_index, [source_process], [False]]) for target_process in target_process_group: + # has_sent means the source_partition_index has been sent to target_process has_sent = [] target_partition_index = Resharder.compute_partition_index( target_process, complete_shape, target_dims_mapping, @@ -1114,6 +1276,7 @@ class Resharder: has_used[i] = True break i += 1 + if i == len(has_used): has_used = list(map(lambda x: False, has_used)) to_send_process = process_list[0] @@ -1127,10 +1290,16 @@ class Resharder: all_partition_index_list.append(source_partition_index) # append send and recv op desc + is_bool = ( + dist_tensor.serial_tensor.dtype == paddle.bool) send_op_desc = SendOpDesc(source_partition_index, - target_process) + to_send_process, + target_process, + is_bool=is_bool) recv_op_desc = RecvOpDesc(source_partition_index, - to_send_process) + to_send_process, + target_process, + is_bool=is_bool) op_desc_seq[to_send_process].append(send_op_desc) op_desc_seq[target_process].append(recv_op_desc) has_sent.append(source_partition_index) @@ -1146,16 +1315,24 @@ class Resharder: slice_ends = [] slices_axes = [] concatenated_partition_index = partition_index_list[0] + to_slice_tensor_shape = [] + for idx, item in enumerate(concatenated_partition_index): slice_starts.append(target_partition_index[idx][0] - item[0]) slice_ends.append(target_partition_index[idx][1] - item[0]) slices_axes.append(idx) + to_slice_tensor_shape.append(item[1] - item[0]) + op_desc_seq[target_process].append( - SliceOpDesc(slice_starts, slice_ends, slices_axes)) + SliceOpDesc(slice_starts, + slice_ends, + slices_axes, + shape=to_slice_tensor_shape)) - # in the same process group, it will use allgahther and slice op + # in the same process group, it will use allgahther and slice op. else: + # NOTE: It just supports even partition scene. partition_index_list = [] all_partition_index_list = [] process_index = [] @@ -1191,17 +1368,21 @@ class Resharder: slice_ends.append(item[1]) slices_axes.append(idx) + to_slice_tensor_shape = dist_tensor.global_sizes() slice_op_desc = SliceOpDesc(starts=slice_starts, ends=slice_ends, - axes=slices_axes) - op_desc_seq[process] = [AllGatherOpDesc(group=group), + axes=slices_axes, + shape=to_slice_tensor_shape) + allgather_shape = None if not serial else dist_tensor.local_sizes( + rank=process) + op_desc_seq[process] = [AllGatherOpDesc(group=group, shape=allgather_shape, is_bool=(source_tensor.dtype == paddle.bool)), ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \ if len(group) > 1 else [slice_op_desc] return op_desc_seq def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op, - actual_process_mesh): + dist_attr): """Parse op desc sequence and insert op in the block""" tensor_list = [] partition_tensor_list = [] @@ -1226,13 +1407,32 @@ class Resharder: self.has_allgather[var_name] = [] if not self.has_allgather[var_name] or op_desc.group not in list( map(lambda x: x[0], self.has_allgather[var_name])): - tensor_list, idx_offset = Inserter.insert_allgather_op( - block, idx, source_tensor, op_desc.group, - reshard_op.attr('op_role')) - idx += idx_offset - tensor_name_list = [var.name for var in tensor_list] - self.has_allgather[var_name].append( - [op_desc.group, tensor_name_list]) + if op_desc.is_bool: + # for bool data allgather, cast to int64 -> allgather -> cast bool + out_cast = Inserter.insert_cast_op( + block, idx, source_tensor, + reshard_op.attr('op_role'), paddle.int64) + tensor_list, idx_offset = Inserter.insert_allgather_op( + block, idx + 1, out_cast, op_desc.group, + reshard_op.attr('op_role')) + idx += idx_offset + tensor_name_list = [] + for var in tensor_list: + out_cast = Inserter.insert_cast_op( + block, idx, var, reshard_op.attr('op_role'), + paddle.bool) + tensor_name_list.append(out_cast.name) + idx += 1 + self.has_allgather[var_name].append( + [op_desc.group, tensor_name_list]) + else: + tensor_list, idx_offset = Inserter.insert_allgather_op( + block, idx, source_tensor, op_desc.group, + reshard_op.attr('op_role')) + idx += idx_offset + tensor_name_list = [var.name for var in tensor_list] + self.has_allgather[var_name].append( + [op_desc.group, tensor_name_list]) else: for item in self.has_allgather[var_name]: if op_desc.group == item[0]: @@ -1249,10 +1449,19 @@ class Resharder: if var_name not in self.has_sent.keys(): self.has_sent[var_name] = [] if op_desc.dst not in self.has_sent[var_name]: - Inserter.insert_send_op(block, idx, source_tensor, - op_desc.dst, - reshard_op.attr('op_role')) - idx += 1 + if op_desc.is_bool: + out_cast = Inserter.insert_cast_op( + block, idx, source_tensor, + reshard_op.attr('op_role'), paddle.int64) + Inserter.insert_send_op(block, idx + 1, out_cast, + op_desc.src, op_desc.dst, + reshard_op.attr('op_role')) + idx += 2 + else: + Inserter.insert_send_op(block, idx, source_tensor, + op_desc.src, op_desc.dst, + reshard_op.attr('op_role')) + idx += 1 self.has_sent[var_name].append(op_desc.dst) elif isinstance(op_desc, RecvOpDesc): @@ -1263,17 +1472,58 @@ class Resharder: shape = [] for index in partition_index: shape.append(index[1] - index[0]) - recv_tensor = block.create_var( - name=unique_name.generate(var_name + "@recv"), - shape=shape, - dtype=source_tensor.dtype, - type=source_tensor.type) - Inserter.insert_recv_op(block, idx, recv_tensor, - op_desc.src, - reshard_op.attr('op_role')) - tensor_list.append(recv_tensor) - idx += 1 - self.has_recv[var_name][op_desc.src] = recv_tensor + if op_desc.is_bool: + # for bool data, recv int64 -> cast to bool + recv_tensor = block.create_var( + name=unique_name.generate(var_name + "@recv"), + shape=shape, + lod_level=source_tensor.lod_level, + dtype=paddle.int64, + type=source_tensor.type) + Inserter.insert_recv_op(block, idx, recv_tensor, + op_desc.src, op_desc.dst, + reshard_op.attr('op_role')) + out_cast = Inserter.insert_cast_op( + block, idx + 1, recv_tensor, + reshard_op.attr('op_role'), paddle.bool) + tensor_list.append(out_cast) + idx += 2 + self.has_recv[var_name][op_desc.src] = out_cast + else: + recv_tensor = block.create_var( + name=unique_name.generate(var_name + "@recv"), + shape=shape, + lod_level=source_tensor.lod_level, + dtype=source_tensor.dtype, + type=source_tensor.type) + Inserter.insert_recv_op(block, idx, recv_tensor, + op_desc.src, op_desc.dst, + reshard_op.attr('op_role')) + + # for lod tensor, need reset lod after received + if recv_tensor.lod_level != 0: + set_lod = False + # use data lod to reset tensor lod + for tmp_block in self.auto_parallel_main_prog.blocks: + for tmp_var_name in tmp_block.vars: + tmp_var = tmp_block.vars[tmp_var_name] + if tmp_var.is_data and tmp_var.lod_level == recv_tensor.lod_level: + reset_lod_out = Inserter.insert_reset_lod_op( + block, idx + 1, recv_tensor, + tmp_var, reshard_op.attr('op_role')) + tensor_list.append(reset_lod_out) + idx += 2 + self.has_recv[var_name][ + op_desc.src] = reset_lod_out + set_lod = True + break + if set_lod: + break + assert set_lod is True + else: + tensor_list.append(recv_tensor) + idx += 1 + self.has_recv[var_name][op_desc.src] = recv_tensor else: tensor_list.append(self.has_recv[var_name][op_desc.src]) @@ -1303,188 +1553,506 @@ class Resharder: new_var_name=new_name, op_role=reshard_op.attr('op_role')) + process_mesh = dist_attr[0] + dims_mapping = dist_attr[1] + tensor_attr = TensorDistributedAttribute() - process_mesh = actual_process_mesh - dims_mapping = self.dist_context.get_op_dist_attr_for_program( - matched_op).get_input_dims_mapping(var_name) tensor_attr.dims_mapping = dims_mapping tensor_attr.process_mesh = process_mesh self.dist_context.set_tensor_dist_attr_for_program( target_tensor, tensor_attr) - if op.type == "while": + if matched_op.type == "while": # var_reshard_mapping means the while op input need be changed to if "var_reshard_mapping" not in Resharder.while_block_info[ op.attr("sub_block").id].keys(): Resharder.while_block_info[op.attr( "sub_block").id]["var_reshard_mapping"] = {} + if var_name not in Resharder.while_block_info[op.attr( + "sub_block").id]["var_reshard_mapping"].keys(): + Resharder.while_block_info[op.attr("sub_block").id][ + "var_reshard_mapping"][var_name] = [] Resharder.while_block_info[op.attr("sub_block").id][ - "var_reshard_mapping"][var_name] = target_tensor.name + "var_reshard_mapping"][var_name].append( + [dist_attr, target_tensor.name]) # rename op input name according to new name for op in block.ops: + # just for while op + while_op_X_append = [] for name in op.input_arg_names: op_dist_attr = self.dist_context.get_op_dist_attr_for_program( op) if name == var_name and op_dist_attr is not None: if op.desc.id() == matched_op.desc.id(): - op.desc._rename_input(name, target_tensor.name) - op_dist_attr.set_input_dims_mapping( - target_tensor.name, dims_mapping) - op_dist_attr.set_input_dist_attr(name, None) - continue + if matched_op.type == "while": + old_name = name + new_name = target_tensor.name + assert old_name != new_name + op_input_dist_attr = op_dist_attr.get_input_dist_attr( + old_name) + op_dist_attr.set_input_dist_attr( + new_name, op_input_dist_attr) + op_dist_attr.set_input_dims_mapping( + new_name, dims_mapping) + if old_name in op_dist_attr._inputs_dist_attrs: + op_dist_attr.del_input_dist_attr( + old_name) + while_op_X_append.append(new_name) + continue + else: + op.desc._rename_input( + name, target_tensor.name) + old_name = name + new_name = target_tensor.name + assert old_name != new_name + op_input_dist_attr = op_dist_attr.get_input_dist_attr( + old_name) + op_dist_attr.set_input_dist_attr( + new_name, op_input_dist_attr) + op_dist_attr.set_input_dims_mapping( + new_name, dims_mapping) + op_dist_attr.del_input_dist_attr(old_name) + continue - # NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation. op_process_mesh = op_dist_attr.process_mesh op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( var_name) + # NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation. if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping: op.desc._rename_input(name, target_tensor.name) + old_name = name + new_name = target_tensor.name + assert old_name != new_name + op_input_dist_attr = op_dist_attr.get_input_dist_attr( + old_name) + op_dist_attr.set_input_dist_attr( + new_name, op_input_dist_attr) op_dist_attr.set_input_dims_mapping( - target_tensor.name, dims_mapping) - op_dist_attr.set_input_dist_attr(name, None) + new_name, dims_mapping) + op_dist_attr.del_input_dist_attr(old_name) - def reshard(self): - for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks): - if block_idx in Resharder.while_block_info: - if "var_reshard_mapping" in Resharder.while_block_info[ - block_idx]: - var_reshard_mapping = Resharder.while_block_info[block_idx][ - "var_reshard_mapping"] - for op in block.ops: - for var_name in op.input_arg_names: - if var_name in var_reshard_mapping: - op.desc._rename_input( - var_name, var_reshard_mapping[var_name]) - dist_op = self.dist_context.get_dist_op_for_program( - op) - op_dist_attr = dist_op.dist_attr - if op_dist_attr.process_mesh == Resharder.while_block_info[ - block_idx]["actual_process_mesh"]: - dims_mapping = op_dist_attr.get_input_dims_mapping( - var_name) - op_dist_attr.set_input_dims_mapping( - var_reshard_mapping[var_name], - dims_mapping) - op_dist_attr.set_input_dist_attr( - var_name, None) - - # the outputs also need to be renamed when the output name is the same with input name - for var_name in op.output_arg_names: - if var_name in var_reshard_mapping: - op.desc._rename_output( - var_name, var_reshard_mapping[var_name]) - dist_op = self.dist_context.get_dist_op_for_program( - op) - op_dist_attr = dist_op.dist_attr - if op_dist_attr.process_mesh == Resharder.while_block_info[ - block_idx]["actual_process_mesh"]: - dims_mapping = op_dist_attr.get_output_dims_mapping( - var_name) - op_dist_attr.set_output_dims_mapping( - var_reshard_mapping[var_name], - dims_mapping) - op_dist_attr.set_output_dist_attr( - var_name, None) - - idx = 0 - while idx < len(block.ops): - pre_op_count = len(block.ops) - op = block.ops[idx] - - if self.is_special_op(op): - idx += 1 - continue + # for while op, the input X should reset + if while_op_X_append: + proto = OpProtoHolder.instance().get_op_proto(op.type) + op.desc.set_input(proto.inputs[0].name, + op.input("X") + while_op_X_append) + + def _get_while_op_input_attrs(self, op, var_name): + # NOTE: Multi while loop is not supported + assert op.type == "while" + sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] + ops = sub_block.ops + input_attrs = [] + + for op in ops: + dist_op = self.dist_context.get_dist_op_for_program(op) + if not dist_op: + continue + dist_attr = dist_op.dist_attr + for name in op.input_arg_names: + if name == var_name: + process_mesh = dist_attr.process_mesh + input_dims_mapping = dist_attr.get_input_dims_mapping( + var_name) + has_exist = False + for input_attr in input_attrs: + if process_mesh == input_attr[ + 0] and input_dims_mapping == input_attr[1]: + has_exist = True + break + if not has_exist: + input_attrs.append([process_mesh, input_dims_mapping]) + return input_attrs + + def _get_common_op_input_attrs(self, op, var_name): + process_meshes = [] + dist_op = self.dist_context.get_dist_op_for_program(op) + dist_attr = dist_op.dist_attr + op_process_mesh = dist_attr.process_mesh + for process_mesh in self.dist_context.process_meshes: + if set(process_mesh.processes) & (set( + op_process_mesh.processes)) and len( + process_mesh.processes) < len( + op_process_mesh.processes): + process_meshes.append(process_mesh) + + # it means that the process mesh is not a union when process meshes is none + if not process_meshes: + process_meshes.append(op_process_mesh) + + input_dims_mapping = dist_attr.get_input_dims_mapping(var_name) + input_attrs = [] + for process_mesh in process_meshes: + input_attrs.append([process_mesh, input_dims_mapping]) + + return input_attrs + + def get_op_input_attrs(self, op, var_name): + op_input_attrs = [] - dist_op = self.dist_context.get_dist_op_for_program(op) - if dist_op is not None: - process_meshes = [] - if op.type == "while": - if not self.is_condition_replicative(op): + if op.type == "while": + op_input_attrs = self._get_while_op_input_attrs(op, var_name) + else: + op_input_attrs = self._get_common_op_input_attrs(op, var_name) + + assert op_input_attrs + + return op_input_attrs + + def _remove_global_process_mesh(self): + """Remove global process mesh from dist_context.process_meshes""" + processes = set() + process_mesh_count = len(self.dist_context.process_meshes) + if process_mesh_count > 1: + global_process_mesh_idx = None + for process_mesh in self.dist_context.process_meshes: + for process in process_mesh.processes: + processes.add(process) + for idx, process_mesh in enumerate( + self.dist_context.process_meshes): + if len(set(process_mesh.processes)) == len(processes): + global_process_mesh_idx = idx + break + if global_process_mesh_idx is not None: + self.dist_context.process_meshes.pop(idx) + + def _change_subblock_op_input_and_output(self, block_idx, block): + if "var_reshard_mapping" in Resharder.while_block_info[block_idx]: + var_reshard_mapping = Resharder.while_block_info[block_idx][ + "var_reshard_mapping"] + for op in block.ops: + for var_name in op.input_arg_names: + if var_name in var_reshard_mapping: + # in while sub block, the union process mesh is not split before reshard sub block + dist_op = self.dist_context.get_dist_op_for_program(op) + dist_attr = dist_op.dist_attr + target_name = None + for item in var_reshard_mapping[var_name]: + if dist_attr.process_mesh == item[0][ + 0] and dist_attr.get_input_dims_mapping( + var_name) == item[0][1]: + target_name = item[1] + break + if target_name is None: + continue + else: + op.desc._rename_input(var_name, target_name) + dist_op = self.dist_context.get_dist_op_for_program( + op) + op_dist_attr = dist_op.dist_attr + old_name = var_name + new_name = target_name + assert old_name != new_name + op_input_dist_attr = op_dist_attr.get_input_dist_attr( + old_name) + op_dist_attr.set_input_dist_attr( + new_name, op_input_dist_attr) + op_dist_attr.del_input_dist_attr(old_name) + + # the outputs also need to be renamed when the output name is the same with input name in inplace op + for var_name in op.output_arg_names: + # if the tensor has been resharded multiply, it is not supported now. + if var_name in var_reshard_mapping: + if len(var_reshard_mapping[var_name]) > 1: raise ValueError( - "Please check the condition due to the dims mapping is not replicative." + "The scene is not supported that the output is inplaced and the tensor has been resharded multiply when as input." ) - process_meshes = self.get_process_meshes(op) - assert process_meshes - if op.attr("sub_block" - ).id not in Resharder.while_block_info: - Resharder.while_block_info[op.attr( - "sub_block").id] = {} - Resharder.while_block_info[op.attr( - "sub_block").id]["op_id"] = op.desc.id() - Resharder.while_block_info[op.attr("sub_block").id][ - "actual_process_mesh"] = self.get_while_op_actual_process_mesh( - op) - else: - process_meshes = self.get_op_process_meshes(op) - input_vars = None - if op.type == "while": - input_var_names = op.input("X") - else: - input_var_names = op.input_arg_names - idx_offset = 0 - for var_name in op.input_arg_names: - # skip lod_tensor_blocking_queue_0 - if var_name == "lod_tensor_blocking_queue_0": - continue - var = get_var_with_recursion( - var_name, block, self.auto_parallel_main_prog) - dist_tensor = self.dist_context.get_dist_tensor_for_program( - var) - for process_mesh in process_meshes: - if dist_tensor is not None and self.need_reshard( - dist_tensor, dist_op, process_mesh): - reshard_op_desc = self.find_op_desc_seq( - dist_tensor, dist_op, process_mesh) - self.parse_op_desc(block, reshard_op_desc, - var_name, op, process_mesh) - cur_op_count = len(block.ops) - idx_offset = idx_offset + cur_op_count - pre_op_count - pre_op_count = cur_op_count - idx = idx + idx_offset + 1 + target_name = var_reshard_mapping[var_name][0][1] + + op.desc._rename_output(var_name, target_name) + dist_op = self.dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + old_name = var_name + new_name = target_name + assert old_name != new_name + op_output_dist_attr = op_dist_attr.get_output_dist_attr( + old_name) + op_dist_attr.set_output_dist_attr( + new_name, op_output_dist_attr) + op_dist_attr.del_output_dist_attr(old_name) + + def _reshard_input(self, block): + idx = 0 + while idx < len(block.ops): + pre_op_count = len(block.ops) + op = block.ops[idx] + + if self.is_special_op(op): + idx += 1 + continue + + dist_op = self.dist_context.get_dist_op_for_program(op) + if dist_op is not None: + op_input_dist_attrs = [ + ] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)] + if op.type == "while": + if not self.is_condition_replicative(op): + raise ValueError( + "Please check the condition due to the dims mapping is not replicative." + ) + if op.attr( + "sub_block").id not in Resharder.while_block_info: + Resharder.while_block_info[op.attr("sub_block").id] = {} + Resharder.while_block_info[op.attr( + "sub_block").id]["op_id"] = op.desc.id() + + if op.type == "while": + # condition var process mesh is the same with op and dims_mapping is replicative, so it do not need reshard + input_var_names = op.input("X") else: - idx += 1 + input_var_names = op.input_arg_names + # to avoid while op X order different + input_var_names.sort() + + idx_offset = 0 + for var_name in input_var_names: + # skip lod_tensor_blocking_queue_0 + if var_name == "lod_tensor_blocking_queue_0": + continue + var = get_var_with_recursion(var_name, block, + self.auto_parallel_main_prog) + dist_tensor = self.dist_context.get_dist_tensor_for_program( + var) + + # judge whether union tensor dims_mapping all -1 + is_union_process_mesh_tensor = False + if dist_tensor.dist_attr.process_mesh not in self.dist_context.process_meshes and self.dist_context.process_meshes: + is_union_process_mesh_tensor = True + assert dist_tensor.dist_attr.dims_mapping.count( + -1) == len(dist_tensor.dist_attr.dims_mapping) + + op_input_attrs = self.get_op_input_attrs(op, var_name) + for input_attr in op_input_attrs: + input_process_mesh = None + + # deal with union tensor + if is_union_process_mesh_tensor: + # if op process mesh is subset of union tensor process mesh, need no reshard + if set(input_attr[0].processes) <= set( + dist_tensor.dist_attr.process_mesh.processes + ): + continue - # insert send and recv op if output process mesh is different from tensor process mesh - idx = 0 - # skip reader and ops whose process mesh is union - skip_ops = [ - "create_py_reader", "create_double_buffer_reader", "read", - "while", "write_to_array", "read_from_array" - ] - global _g_special_ops - skip_ops += _g_special_ops - while idx < len(block.ops): - pre_op_count = len(block.ops) - op = block.ops[idx] - dist_op = self.dist_context.get_dist_op_for_program(op) - if dist_op is not None and op.type not in skip_ops: - for var_name in op.output_arg_names: - var = get_var_with_recursion( - var_name, block, self.auto_parallel_main_prog) - dist_tensor = self.dist_context.get_dist_tensor_for_program( - var) - process_mesh = dist_op.dist_attr.process_mesh if dist_tensor is not None and self.need_reshard( - dist_tensor, dist_op, process_mesh, False): - for index, item in enumerate( - dist_op.dist_attr.process_mesh.processes): - recv_rank = dist_tensor.dist_attr.process_mesh.processes[ - index] - if self.rank_id == item: - Inserter.insert_send_op( - block, idx + 1, var, recv_rank, - op.attr('op_role')) - if self.rank_id == recv_rank: - Inserter.insert_recv_op( - block, idx + 1, var, item, - op.attr('op_role')) + dist_tensor, input_attr): + reshard_op_desc = self.find_op_desc_seq( + dist_tensor, input_attr) + self.parse_op_desc(block, reshard_op_desc, var_name, + op, input_attr) cur_op_count = len(block.ops) idx_offset = idx_offset + cur_op_count - pre_op_count pre_op_count = cur_op_count - idx = idx + idx_offset + 1 + idx = idx + idx_offset + 1 + else: + idx += 1 + + def _hadnle_recv(self, block, idx, var, op, send_rank, recv_rank): + if self.rank_id == recv_rank: + # if recv bool data, recv then cast + if var.dtype == paddle.bool: + recv_cast_out = block.create_var( + name=unique_name.generate(var.name + "@recv"), + shape=var.shape, + lod_level=var.lod_level, + dtype=paddle.int64, + type=var.type) + Inserter.insert_recv_op(block, idx + 1, + recv_cast_out, send_rank, recv_rank, + op.attr('op_role')) + reset_lod_out = None + if var.lod_level != 0: + set_lod = False + for tmp_block in self.auto_parallel_main_prog.blocks: + for tmp_var_name in tmp_block.vars: + tmp_var = tmp_block.vars[tmp_var_name] + if tmp_var.is_data and tmp_var.lod_level == var.lod_level: + reset_lod_out = block.create_var( + name=unique_name.generate(var.name + + "@RESETLOD"), + shape=recv_cast_out.shape, + type=recv_cast_out.type, + dtype=recv_cast_out.dtype, + lod_level=recv_cast_out.lod_level) + idx += 1 + block._insert_op( + idx, + type="lod_reset", + inputs={ + 'X': recv_cast_out, + 'Y': tmp_var + }, + outputs={'Out': reset_lod_out}, + attrs={'op_role': op.attr("op_role")}) + set_lod = True + break + if set_lod: + break + assert set_lod is True + + # cast int64 to bool + block._insert_op(idx + 2, + type='cast', + inputs={ + 'X': [recv_cast_out] if + reset_lod_out is None else [reset_lod_out] + }, + outputs={'Out': [var]}, + attrs={ + 'in_dtype': recv_cast_out.dtype, + 'out_dtype': var.dtype, + 'op_role': op.attr('op_role') + }) + else: + if var.lod_level != 0: + recv_out = block.create_var( + name=unique_name.generate(var.name + "@recv"), + shape=var.shape, + lod_level=var.lod_level, + dtype=var.int64, + type=var.type) + Inserter.insert_recv_op(block, idx + 1, recv_out, send_rank, + recv_rank, op.attr('op_role')) + set_lod = False + for tmp_block in self.auto_parallel_main_prog.blocks: + for tmp_var_name in tmp_block.vars: + tmp_var = tmp_block.vars[tmp_var_name] + if tmp_var.is_data and tmp_var.lod_level == var.lod_level: + idx += 1 + block._insert_op( + idx, + type="lod_reset", + inputs={ + 'X': recv_out, + 'Y': tmp_var + }, + outputs={'Out': var}, + attrs={'op_role': op.attr("op_role")}) + set_lod = True + break + if set_lod: + break + assert set_lod is True else: - idx += 1 + Inserter.insert_recv_op(block, idx + 1, var, send_rank, + recv_rank, op.attr('op_role')) + + def _handle_send(self, block, idx, var, op, send_rank, recv_rank): + if var.dtype == paddle.bool: + cast_out = Inserter.insert_cast_op(block, idx + 1, var, + op.attr('op_role'), paddle.int64) + Inserter.insert_send_op(block, idx + 2, cast_out, send_rank, + recv_rank, op.attr('op_role')) + else: + Inserter.insert_send_op(block, idx + 1, var, send_rank, recv_rank, + op.attr('op_role')) + + def _reshard_output(self, block): + # insert send and recv op if output process mesh is different from tensor process mesh + idx = 0 + # skip reader and ops whose process mesh is union + skip_ops = [ + "create_py_reader", "create_double_buffer_reader", "read", "while", + "write_to_array", "read_from_array" + ] + global _g_special_ops + skip_ops += _g_special_ops + while idx < len(block.ops): + pre_op_count = len(block.ops) + op = block.ops[idx] + dist_op = self.dist_context.get_dist_op_for_program(op) + if dist_op is not None and op.type not in skip_ops: + idx_offset = 0 + for var_name in op.output_arg_names: + var = get_var_with_recursion(var_name, block, + self.auto_parallel_main_prog) + dist_tensor = self.dist_context.get_dist_tensor_for_program( + var) + tensor_process_mesh = dist_tensor.dist_attr.process_mesh + output_attr = [ + dist_op.dist_attr.process_mesh, + dist_op.dist_attr.get_output_dims_mapping(var_name) + ] + if dist_tensor is not None and self.need_reshard( + dist_tensor, output_attr, False): + tensor_processes = set( + tensor_process_mesh.processes) - ( + set(tensor_process_mesh.processes) + & set(output_attr[0].processes)) + if tensor_processes: + if len(tensor_processes) != len( + output_attr[0].processes): + if dist_tensor.dist_attr.dims_mapping.count( + -1) != len( + dist_tensor.dist_attr.dims_mapping + ) or output_attr[1].count(-1) != len( + output_attr[1]): + raise ValueError( + "The dims_mapping must be -1") + else: + for index, tensor_process in enumerate( + tensor_processes): + recv_rank = tensor_process + actual_index = index + if index >= len( + output_attr[0].processes): + actual_index = ( + index - + len(output_attr[0].processes) + ) % len(output_attr[0].processes) + item = output_attr[0].processes[ + actual_index] + if recv_rank == item: + continue + if self.rank_id == item: + # if send bool data, cast then send + self._handle_send( + block, idx, var, op, item, + recv_rank) + if self.rank_id == recv_rank: + # if recv bool data, recv then cast + self._hadnle_recv( + block, idx, var, op, item, + recv_rank) + else: + for index, tensor_process in enumerate( + tensor_processes): + recv_rank = tensor_process + item = output_attr[0].processes[index] + if recv_rank == item: + continue + if self.rank_id == item: + # if send bool data, cast then send + self._handle_send( + block, idx, var, op, item, + recv_rank) + if self.rank_id == recv_rank: + # if recv bool data, recv then cast + self._hadnle_recv( + block, idx, var, op, item, + recv_rank) + + cur_op_count = len(block.ops) + idx_offset = idx_offset + cur_op_count - pre_op_count + pre_op_count = cur_op_count + + idx = idx + idx_offset + 1 + else: + idx += 1 + + def reshard(self): + self._remove_global_process_mesh() + for block_idx, block in enumerate(self.auto_parallel_main_prog.blocks): + # change the var_name before resharding sub block + if block_idx in Resharder.while_block_info: + self._change_subblock_op_input_and_output(block_idx, block) + + # reshard input + self._reshard_input(block) + + # reshard output + # NOTE: Only support that insert send and recv op if output process mesh is different from tensor process mesh + self._reshard_output(block) # remove no need vars and ops in the main program Remover.remove_no_need_in_main(self.auto_parallel_main_prog, diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index d6fd06647ba..46e77eb50ca 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1419,7 +1419,10 @@ def get_standalone_cost_data(distributed_programs): } standalone_cost_data = [] - not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"] + # skip ops + not_enum_ops = [ + "create_py_reader", "create_double_buffer_reader", "read", "assign" + ] for distributed_program in distributed_programs: cost_data = {} vars = distributed_program.global_block().vars diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index c6a8f11574f..6e07e16e97f 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -27,7 +27,10 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() -_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read', 'slice'] +_skip_ops = [ + 'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split', + 'assign' +] # update here to support new optimizers _supported_optimizer_type = [ "adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum", diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py b/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py index d3a4a4898bf..9d2b2739401 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py @@ -107,7 +107,9 @@ class TestDataUnshard(unittest.TestCase): input_data = np.array(range(2 * 8)).reshape([2, 8]).astype("float32") label_data = np.random.randint(0, 10, [2, 8]).astype("float32") - fetchs = [loss.name, 'input@RESHARD_0'] + fetchs = [loss.name, 'split@RESHARD.tmp_0'] if worker_index == 0 else [ + loss.name, 'split@RESHARD.tmp_1' + ] loss_np, shard_data_np = exe.run(distributed_main_program, feed={ "input": input_data, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index 93c5ded0c10..4df770df696 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -28,7 +28,7 @@ from paddle.distributed import fleet from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import Resharder -from paddle.distributed.auto_parallel.process_group import _g_process_group_map +from paddle.distributed.auto_parallel.process_group import _g_process_group_map, ProcessGroup from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() @@ -307,6 +307,7 @@ class TestMLPReshard(unittest.TestCase): train_program, startup_program, dist_context, rank_id) for key in list(_g_process_group_map.keys()): del _g_process_group_map[key] + _g_process_group_map[0] = ProcessGroup(0, []) resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() @@ -326,10 +327,10 @@ class TestMLPReshard(unittest.TestCase): train_program, startup_program, dist_context, rank_id, True) for key in list(_g_process_group_map.keys()): del _g_process_group_map[key] + _g_process_group_map[0] = ProcessGroup(0, []) resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() - # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_initialization(dist_startup_prog, rank_id)) -- GitLab