diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index 44521c05994ddb3bd38680ab15dcb33bf27c82d9..2dd85df800a47abfe111e25ba0b5178cc892b14d 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -92,6 +92,42 @@ class AllGatherOpDesc:
         return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
 
 
+class AllGatherConcatOpDesc:
+    """
+    Describe the c_concat op in the reshard phase.
+
+    Args:
+        group (list): Process group.
+        shape (list): The tensor shape.
+        is_bool (bool): Whether to c_concat bool data. Default: False.
+    """
+
+    def __init__(self, group, shape, is_bool=False):
+        self._group = group
+        self._desc = "c_concat"
+        self._shape = shape
+        self._is_bool = is_bool
+
+    @property
+    def is_bool(self):
+        return self._is_bool
+
+    @property
+    def group(self):
+        return self._group
+
+    @property
+    def desc(self):
+        return self._desc
+
+    @property
+    def shape(self):
+        return self._shape
+
+    def __repr__(self):
+        return f"op: {self._desc}, group: {self._group}, shape: {self._shape}, is_bool: {self._is_bool}."
+
+
 class SendOpDesc:
     """
     Describe the send op in the reshard phase.
@@ -640,6 +676,46 @@ class Inserter:
                 tensor_list.extend(split_out)
         return tensor_list, idx_offset
 
+    @staticmethod
+    def insert_c_concat_op(block, idx, tensor, ranks, op_role):
+        """Insert c_concat op into block at the given index."""
+        group = new_process_group(ranks)
+        idx_offset = 0
+
+        # insert c_concat op
+        op_type = 'c_concat'
+        # to avoid name conflict with framework
+        helper = LayerHelper(op_type + "@RESHARD", **locals())
+        with paddle.static.program_guard(block.program):
+            c_concat_out = block.create_var(
+                name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                    ".".join([helper.name, 'tmp'])
+                ),
+                dtype=tensor.dtype,
+                shape=None,
+                lod_level=tensor.lod_level,
+                type=tensor.type,
+                persistable=False,
+                stop_gradient=False,
+            )
+        cur_rank = paddle.distributed.get_rank()
+        c_concat_op = block._insert_op(
+            idx + idx_offset,
+            type=op_type,
+            inputs={'X': [tensor]},
+            outputs={'Out': [c_concat_out]},
+            attrs={
+                'ring_id': group.id,
+                'use_calc_stream': True,
+                'use_model_parallel': True,
+                'nranks': group.nranks,
+                'op_role': op_role,
+                'rank': group.ranks.index(cur_rank) if cur_rank in ranks else 0,
+            },
+        )
+        c_concat_op._set_attr('op_namescope', "/auto_parallel/reshard")
+        return c_concat_out
+
     @staticmethod
     def concat_partitions_with_op(
         partition_tensor_list, tensor, partition_index, block, idx, op_role
@@ -1535,7 +1611,7 @@ class Resharder:
                     )
                 )
 
-            # in the same process group, it will use allgahther and slice op.
+            # In the same process group, it will use allgather and slice op.
             else:
                 # NOTE: It just supports even partition scene.
                 partition_index_list = []
@@ -1599,21 +1675,37 @@
                         if not serial
                         else dist_tensor.local_sizes(rank=process)
                     )
-                    op_desc_seq[process] = (
-                        [
-                            AllGatherOpDesc(
-                                group=group,
-                                shape=allgather_shape,
-                                is_bool=(source_tensor.dtype == paddle.bool),
-                            ),
-                            ConcatOpDesc(
-                                partition_index_list=all_partition_index_list
-                            ),
-                            slice_op_desc,
+                    # c_concat pass
+                    if (
+                        target_dims_mapping.count(-1)
+                        == len(target_dims_mapping)
+                        and source_dims_mapping[:-1].count(-1)
+                        == len(source_dims_mapping[:-1])
+                        and source_dims_mapping[-1] != -1
+                    ):
+                        op_desc_seq[process] = [
+                            AllGatherConcatOpDesc(
+                                group=group, shape=allgather_shape
+                            )
                         ]
-                        if len(group) > 1
-                        else [slice_op_desc]
-                    )
+                    else:
+                        op_desc_seq[process] = (
+                            [
+                                AllGatherOpDesc(
+                                    group=group,
+                                    shape=allgather_shape,
+                                    is_bool=(
+                                        source_tensor.dtype == paddle.bool
+                                    ),
+                                ),
+                                ConcatOpDesc(
+                                    partition_index_list=all_partition_index_list
+                                ),
+                                slice_op_desc,
+                            ]
+                            if len(group) > 1
+                            else [slice_op_desc]
+                        )
 
         return op_desc_seq
 
@@ -1850,27 +1942,41 @@
                     )
                 idx = idx_list[0]
 
-            elif isinstance(op_desc, SliceOpDesc):
-                assert (
-                    len(partition_tensor_list) == 1 or not partition_tensor_list
-                )
-                to_slice_tensor = (
-                    partition_tensor_list[0][0]
-                    if len(partition_tensor_list) == 1
-                    else source_tensor
-                )
-                new_name = unique_name.generate(var_name + "@RESHARD")
-                target_tensor = Inserter.insert_slice_op(
-                    block,
-                    idx,
-                    to_slice_tensor,
-                    starts=op_desc.starts,
-                    ends=op_desc.ends,
-                    axes=op_desc.axes,
-                    new_var_name=new_name,
-                    op_role=reshard_op.attr('op_role'),
-                )
+            elif isinstance(op_desc, SliceOpDesc) or isinstance(
+                op_desc, AllGatherConcatOpDesc
+            ):
+                target_tensor = None
+                if isinstance(op_desc, SliceOpDesc):
+                    assert (
+                        len(partition_tensor_list) == 1
+                        or not partition_tensor_list
+                    )
+                    to_slice_tensor = (
+                        partition_tensor_list[0][0]
+                        if len(partition_tensor_list) == 1
+                        else source_tensor
+                    )
+                    new_name = unique_name.generate(var_name + "@RESHARD")
+                    target_tensor = Inserter.insert_slice_op(
+                        block,
+                        idx,
+                        to_slice_tensor,
+                        starts=op_desc.starts,
+                        ends=op_desc.ends,
+                        axes=op_desc.axes,
+                        new_var_name=new_name,
+                        op_role=reshard_op.attr('op_role'),
+                    )
+                else:
+                    target_tensor = Inserter.insert_c_concat_op(
+                        block,
+                        idx,
+                        source_tensor,
+                        op_desc.group,
+                        reshard_op.attr('op_role'),
+                    )
+                assert target_tensor is not None
 
                 process_mesh = dist_attr[0]
                 dims_mapping = dist_attr[1]
 
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
index add1d50180ab039b696ef104f4276eba1b04596f..42a2e6ff798ffb9b3b38387053822ab67c660307 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
@@ -304,6 +304,60 @@ class TestMLPReshard(unittest.TestCase):
         # the x should not be slice
         self.assertTrue(check_allgather(partitioned_main_prog))
 
+    def test_c_concat(self):
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
+        with static.program_guard(train_program, startup_program):
+            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
+            x = auto.shard_tensor(x, process_mesh, [None, "x"])
+            w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
+            w = auto.shard_tensor(w, process_mesh, [None, None])
+
+            y = paddle.distributed.shard_op(
+                paddle.matmul, process_mesh, [[None, None], [None, None]]
+            )(x, w)
+
+        rank_id = 0
+        dist_context = DistributedContext()
+        dist_strategy = fleet.DistributedStrategy()
+        partitioner = Partitioner(dist_context, rank_id)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program
+        )
+        dist_context.block_state.parse_forward_blocks(complete_train_program)
+        (
+            partitioned_main_prog,
+            partitioned_startup_prog,
+            partitioned_params_grads,
+        ) = partitioner.partition(complete_train_program, startup_program, [])
+
+        # test estimator
+        cluster = Cluster()
+        cluster.gen_default_config_cluster(device_count=2)
+        cost_estimator = CostEstimator(train_program, cluster)
+        global_cost = cost_estimator.estimate(dist_context)
+        max_memory = cost_estimator._estimate_max_memory_by_dist_op(
+            dist_context
+        )
+        # test cache
+        global_cost = cost_estimator.estimate(dist_context)
+        max_memory = cost_estimator._estimate_max_memory_by_dist_op(
+            dist_context
+        )
+        assert global_cost.time >= 0
+        assert max_memory > 0
+
+        resharder = Resharder(
+            partitioned_main_prog,
+            partitioned_startup_prog,
+            rank_id,
+            dist_context,
+            partitioned_params_grads,
+        )
+        resharder.reshard()
+
 
 if __name__ == "__main__":
     unittest.main()
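
A standalone sketch of the dims-mapping check that the reshard.py hunk above introduces: the new c_concat path is chosen only when the target is fully replicated and the source is sharded on its last dimension alone; otherwise the existing allgather + concat + slice sequence is kept. The uses_c_concat helper and the example dims_mapping values below are illustrative only and are not part of the patch.

    # Mirrors the dims-mapping condition added in reshard.py above.
    # In a dims_mapping, -1 means the tensor dimension is not sharded.
    def uses_c_concat(source_dims_mapping, target_dims_mapping):
        return (
            target_dims_mapping.count(-1) == len(target_dims_mapping)
            and source_dims_mapping[:-1].count(-1)
            == len(source_dims_mapping[:-1])
            and source_dims_mapping[-1] != -1
        )

    assert uses_c_concat([-1, 0], [-1, -1])      # source sharded only on its last dim
    assert not uses_c_concat([0, -1], [-1, -1])  # sharded on a non-last dim
    assert not uses_c_concat([-1, 0], [0, -1])   # target not fully replicated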