diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 6a8ea7d1daab1a7d6a65ead4b9153f3f5923aa8a..a5260ac3b2ef1ba81f1fbbea202d226590ec1f62 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -738,14 +738,23 @@ void* GetPointerByOffset(void* raw_pointer, } else if (type == experimental::DataType::FLOAT64) { return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) + offset); + } else if (type == experimental::DataType::FLOAT16) { + return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) + + offset); } else if (type == experimental::DataType::INT32) { return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) + offset); } else if (type == experimental::DataType::INT64) { return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) + offset); - } else if (type == experimental::DataType::FLOAT16) { - return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) + + } else if (type == experimental::DataType::INT8) { + return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) + + offset); + } else if (type == experimental::DataType::UINT8) { + return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) + + offset); + } else if (type == experimental::DataType::BOOL) { + return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) + offset); } else { PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/phi/kernels/cpu/concat_kernel.cc b/paddle/phi/kernels/cpu/concat_kernel.cc index 6be825d4ef14e8e9aabf9c1b5b804c3ff5a18347..a80c9db43c8b46ebc5d5a5bacf8b99bed475cd29 100644 --- a/paddle/phi/kernels/cpu/concat_kernel.cc +++ b/paddle/phi/kernels/cpu/concat_kernel.cc @@ -124,6 +124,8 @@ PD_REGISTER_KERNEL(concat, int64_t, int, uint8_t, + int8_t, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} diff --git a/paddle/phi/kernels/gpu/concat_kernel.cu b/paddle/phi/kernels/gpu/concat_kernel.cu index accb1cc3d77e3ccd14b4d7808b781cf255eddd06..6d32205b0bb643491039cc8c43e55b59893de5d9 100644 --- a/paddle/phi/kernels/gpu/concat_kernel.cu +++ b/paddle/phi/kernels/gpu/concat_kernel.cu @@ -121,6 +121,7 @@ PD_REGISTER_KERNEL(concat, int64_t, int, uint8_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex<float>, diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 74e350b4a537c94012fd13af4ca2c7b11cad84b7..9900195c2030f2e2b14c31152b32ba037bd15447 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -60,21 +60,18 @@ class ReduceOp: Examples: ..
code-block:: python - import numpy as np + # required: distributed import paddle - from paddle.distributed import ReduceOp - from paddle.distributed import init_parallel_env + import paddle.distributed as dist - paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) - init_parallel_env() - if paddle.distributed.ParallelEnv().local_rank == 0: - np_data = np.array([[4, 5, 6], [4, 5, 6]]) + dist.init_parallel_env() + if dist.get_rank() == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) else: - np_data = np.array([[1, 2, 3], [1, 2, 3]]) - data = paddle.to_tensor(np_data) - paddle.distributed.all_reduce(data, op=ReduceOp.SUM) - out = data.numpy() - # [[5, 7, 9], [5, 7, 9]] + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + dist.all_reduce(data, op=dist.ReduceOp.SUM) + print(data) + # [[5, 7, 9], [5, 7, 9]] (2 GPUs) """ SUM = 0 MAX = 1 @@ -589,15 +586,16 @@ def destroy_process_group(group=None): # required: distributed import paddle + import paddle.distributed as dist - paddle.distributed.init_parallel_env() - group = paddle.distributed.new_group([0, 1]) + dist.init_parallel_env() + group = dist.new_group([0, 1]) - paddle.distributed.destroy_process_group(group) - print(paddle.distributed.is_initialized()) + dist.destroy_process_group(group) + print(dist.is_initialized()) # True - paddle.distributed.destroy_process_group() - print(paddle.distributed.is_initialized()) + dist.destroy_process_group() + print(dist.is_initialized()) # False """ @@ -690,8 +688,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True): """ Broadcast a tensor from the source to all others. - As shown below, 4 GPUs each start 4 processes and GPU0 owns data 0. Through broadcast operator, - the data 0 will be sent to all GPUs from GPU0. + As shown below, one process is started with a GPU and GPU0 owns data 0. Through broadcast operator, + data 0 will be sent to all GPUs from GPU0. .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png :width: 800 @@ -699,8 +697,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True): :align: center Args: - tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type - should be float16, float32, float64, int32 or int64. + tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type + should be float16, float32, float64, int32, int64, int8, uint8 or bool. src (int): The source rank. group (Group): The group instance return by new_group or None for global default group. use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False). @@ -713,20 +711,17 @@ def broadcast(tensor, src, group=None, use_calc_stream=True): .. 
code-block:: python # required: distributed - import numpy as np import paddle - from paddle.distributed import init_parallel_env + import paddle.distributed as dist - paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) - init_parallel_env() - if paddle.distributed.ParallelEnv().local_rank == 0: - np_data = np.array([[4, 5, 6], [4, 5, 6]]) + dist.init_parallel_env() + if dist.get_rank() == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) else: - np_data = np.array([[1, 2, 3], [1, 2, 3]]) - data = paddle.to_tensor(np_data) - paddle.distributed.broadcast(data, 1) - out = data.numpy() - # [[1, 2, 3], [1, 2, 3]] + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + dist.broadcast(data, src=1) + print(data) + # [[1, 2, 3], [1, 2, 3]] (2 GPUs) """ if group is not None and not group.is_member(): @@ -756,9 +751,10 @@ def broadcast(tensor, src, group=None, use_calc_stream=True): 'ring_id', ring_id) op_type = 'c_broadcast' - check_variable_and_dtype( - tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], - 'broadcast') + check_variable_and_dtype(tensor, 'tensor', [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ], 'broadcast') helper = LayerHelper(op_type, **locals()) helper.append_op(type=op_type, @@ -800,15 +796,16 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True): # required: distributed import paddle - from paddle.distributed import init_parallel_env + import paddle.distributed as dist - paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) - init_parallel_env() - if paddle.distributed.ParallelEnv().local_rank == 0: + dist.init_parallel_env() + if dist.get_rank() == 0: data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) else: data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) - paddle.distributed.all_reduce(data) + dist.all_reduce(data) + print(data) + # [[5, 7, 9], [5, 7, 9]] (2 GPUs) """ if group is not None and not group.is_member(): return @@ -871,8 +868,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True): def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): """ - Reduce a tensor to the destination from all others. As shown below, 4 GPUs each start 4 processes and the data on each GPU is respresnted - by the GPU number. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator, + Reduce a tensor to the destination from all others. As shown below, one process is started with a GPU and the data of this process is represented + by its group rank. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator, the GPU0 will owns the sum of all data from all GPUs. .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png @@ -882,7 +879,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): Args: tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8 or bool. dst (int): The destination rank id. op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM. group (Group): The group instance return by new_group or None for global default group. @@ -896,20 +893,18 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): .. 
code-block:: python # required: distributed - import numpy as np import paddle - from paddle.distributed import init_parallel_env + import paddle.distributed as dist - paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) - init_parallel_env() - if paddle.distributed.ParallelEnv().local_rank == 0: - np_data = np.array([[4, 5, 6], [4, 5, 6]]) + dist.init_parallel_env() + if dist.get_rank() == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) else: - np_data = np.array([[1, 2, 3], [1, 2, 3]]) - data = paddle.to_tensor(np_data) - paddle.distributed.reduce(data, 0) - out = data.numpy() - # [[5, 7, 9], [5, 7, 9]] + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + dist.reduce(data, dst=0) + print(data) + # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0) + # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1) """ if group is not None and not group.is_member(): return @@ -952,9 +947,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): raise ValueError("Unknown parameter: {}.".format(op)) op_type = 'c_reduce' - check_variable_and_dtype( - tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], - 'all_reduce') + check_variable_and_dtype(tensor, 'tensor', [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ], 'reduce') if op == ReduceOp.SUM: op_type = 'c_reduce_sum' @@ -980,8 +976,8 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): """ Gather tensors from all participators and all get the result. As shown - below, 4 GPUs each starts 4 processes and the data on each GPU is represented - by the GPU number. Through the all_gather operator, each GPU will have data + below, one process is started with a GPU and the data of this process is represented + by its group rank. Through the all_gather operator, each GPU will have data from all GPUs. .. 
image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png @@ -1006,17 +1002,17 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): # required: distributed import paddle - from paddle.distributed import init_parallel_env + import paddle.distributed as dist - paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) - init_parallel_env() + dist.init_parallel_env() tensor_list = [] - if paddle.distributed.ParallelEnv().local_rank == 0: - data1 = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) - paddle.distributed.all_gather(tensor_list, data1) + if dist.get_rank() == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) else: - data2 = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) - paddle.distributed.all_gather(tensor_list, data2) + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + dist.all_gather(tensor_list, data) + print(tensor_list) + # [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs) """ if group is not None and not group.is_member(): return @@ -1126,15 +1122,15 @@ def all_gather_object(object_list, obj, group=None): import paddle import paddle.distributed as dist - paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) dist.init_parallel_env() object_list = [] - if paddle.distributed.ParallelEnv().local_rank == 0: + if dist.get_rank() == 0: obj = {"foo": [1, 2, 3]} - paddle.distributed.all_gather_object(object_list, obj) else: obj = {"bar": [4, 5, 6]} - paddle.distributed.all_gather_object(object_list, obj) + dist.all_gather_object(object_list, obj) + print(object_list) + # [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs) """ assert in_dygraph_mode( ), "all_gather_object doesn't support static graph mode." @@ -1163,7 +1159,7 @@ def all_gather_object(object_list, obj, group=None): def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): """ - Scatter a tensor to all participators. As shown below, 4 GPUs each start 4 processes and the source of the scatter + Scatter a tensor to all participators. As shown below, one process is started with a GPU and the source of the scatter is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs averagely. .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png @@ -1173,9 +1169,9 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): Args: tensor (Tensor): The output Tensor. Its data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8 or bool. tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type - should be float16, float32, float64, int32 or int64. Default value is None. + should be float16, float32, float64, int32, int64, int8, uint8 or bool. Default value is None. src (int): The source rank id. Default value is 0. group (Group): The group instance return by new_group or None for global default group. use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False). @@ -1188,25 +1184,21 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): .. 
code-block:: python # required: distributed - import numpy as np import paddle - from paddle.distributed import init_parallel_env + import paddle.distributed as dist - paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) - init_parallel_env() - if paddle.distributed.ParallelEnv().local_rank == 0: - np_data1 = np.array([7, 8, 9]) - np_data2 = np.array([10, 11, 12]) - else: - np_data1 = np.array([1, 2, 3]) - np_data2 = np.array([4, 5, 6]) - data1 = paddle.to_tensor(np_data1) - data2 = paddle.to_tensor(np_data2) - if paddle.distributed.ParallelEnv().local_rank == 0: - paddle.distributed.scatter(data1, src=1) + dist.init_parallel_env() + if dist.get_rank() == 0: + data1 = paddle.to_tensor([7, 8, 9]) + data2 = paddle.to_tensor([10, 11, 12]) + dist.scatter(data1, src=1) else: - paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1) - out = data1.numpy() + data1 = paddle.to_tensor([1, 2, 3]) + data2 = paddle.to_tensor([4, 5, 6]) + dist.scatter(data1, tensor_list=[data1, data2], src=1) + print(data1, data2) + # [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0) + # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1) """ if group is not None and not group.is_member(): return @@ -1244,9 +1236,10 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): use_calc_stream, 'ring_id', ring_id, 'nranks', nranks, 'root', gsrc) op_type = 'c_scatter' - check_variable_and_dtype( - tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], - 'scatter') + check_variable_and_dtype(tensor, 'tensor', [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ], 'scatter') helper = LayerHelper(op_type, **locals()) helper.append_op(type=op_type, inputs={'X': [temp]}, @@ -2014,7 +2007,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): Args: in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8 or bool. out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the data type of the input Tensors. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. @@ -2027,29 +2020,29 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): .. 
code-block:: python # required: distributed - import numpy as np import paddle - from paddle.distributed import init_parallel_env - - init_parallel_env() + import paddle.distributed as dist + + dist.init_parallel_env() out_tensor_list = [] - if paddle.distributed.ParallelEnv().rank == 0: - np_data1 = np.array([[1, 2, 3], [4, 5, 6]]) - np_data2 = np.array([[7, 8, 9], [10, 11, 12]]) + if dist.get_rank() == 0: + data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]]) else: - np_data1 = np.array([[13, 14, 15], [16, 17, 18]]) - np_data2 = np.array([[19, 20, 21], [22, 23, 24]]) - data1 = paddle.to_tensor(np_data1) - data2 = paddle.to_tensor(np_data2) - paddle.distributed.alltoall([data1, data2], out_tensor_list) - # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] - # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] + data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]]) + data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]]) + dist.alltoall([data1, data2], out_tensor_list) + print(out_tensor_list) + # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0) + # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1) """ if group is not None and not group.is_member(): return if in_dygraph_mode(): group = _get_default_group() if group is None else group + backend = _group_map_backend[group] + assert backend != 'gloo', ("backend gloo is not supported yet") else: ring_id = 0 if group is None else group.id @@ -2114,7 +2107,7 @@ def alltoall_single(in_tensor, ``alltoall_single`` is only supported in eager mode. Args: - in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32 or int64. + in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8 or bool. out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor. in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor`` must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None. 
@@ -2137,35 +2130,36 @@ def alltoall_single(in_tensor, rank = dist.get_rank() size = dist.get_world_size() - # case 1 - input = paddle.arange(2, dtype='int64') + rank * 2 - # input for rank 0: [0, 1] - # input for rank 1: [2, 3] - + # case 1 (2 GPUs) + data = paddle.arange(2, dtype='int64') + rank * 2 + # data for rank 0: [0, 1] + # data for rank 1: [2, 3] output = paddle.empty([2], dtype='int64') - dist.alltoall_single(input, output) + dist.alltoall_single(data, output) + print(output) # output for rank 0: [0, 2] # output for rank 1: [1, 3] - # case 2 + # case 2 (2 GPUs) in_split_sizes = [i + 1 for i in range(size)] - # in_split_sizes for rank 0: [1, 2] and for rank 1: [1, 2] + # in_split_sizes for rank 0: [1, 2] + # in_split_sizes for rank 1: [1, 2] out_split_sizes = [rank + 1 for i in range(size)] - # out_split_sizes for rank 0: [1, 1] and for rank 1: [2, 2] - - input = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank - # input for rank 0: [[0., 0.], [0., 0.], [0., 0.]] - # input for rank 1: [[1., 1.], [1., 1.], [1., 1.]] + # out_split_sizes for rank 0: [1, 1] + # out_split_sizes for rank 1: [2, 2] + data = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank + # data for rank 0: [[0., 0.], [0., 0.], [0., 0.]] + # data for rank 1: [[1., 1.], [1., 1.], [1., 1.]] output = paddle.empty([(rank + 1) * size, size], dtype='float32') - group = dist.new_group([0, 1]) - task = dist.alltoall_single(input, + task = dist.alltoall_single(data, output, in_split_sizes, out_split_sizes, use_calc_stream=False, group=group) task.wait() + print(output) # output for rank 0: [[0., 0.], [1., 1.]] # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]] @@ -2177,6 +2171,9 @@ def alltoall_single(in_tensor, # _check_single_tensor group = _get_default_group() if group is None else group + backend = _group_map_backend[group] + assert backend != 'gloo', ("backend gloo is not supported yet") + in_split_sizes = [] if in_split_sizes is None else in_split_sizes out_split_sizes = [] if out_split_sizes is None else out_split_sizes @@ -2199,7 +2196,7 @@ def send(tensor, dst=0, group=None, use_calc_stream=True): Args: tensor (Tensor): The Tensor to send. Its data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8 or bool. dst (int): The destination rank id. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True. 
@@ -2212,22 +2209,25 @@ def send(tensor, dst=0, group=None, use_calc_stream=True): # required: distributed import paddle - from paddle.distributed import init_parallel_env + import paddle.distributed as dist - init_parallel_env() - if paddle.distributed.ParallelEnv().rank == 0: + dist.init_parallel_env() + if dist.get_rank() == 0: data = paddle.to_tensor([7, 8, 9]) - paddle.distributed.send(data, dst=1) + dist.send(data, dst=1) else: - data = paddle.to_tensor([1,2,3]) - paddle.distributed.recv(data, src=0) - out = data.numpy() + data = paddle.to_tensor([1, 2, 3]) + dist.recv(data, src=0) + print(data) + # [7, 8, 9] (2 GPUs) """ if group is not None and not group.is_member(): return dst = _get_group_rank(dst, group) if in_dygraph_mode(): group = _get_default_group() if group is None else group + backend = _group_map_backend[group] + assert backend != 'gloo', ("backend gloo is not supported yet") task = group.process_group.send(tensor, dst) if use_calc_stream: task.wait() @@ -2261,7 +2261,7 @@ def recv(tensor, src=0, group=None, use_calc_stream=True): Args: tensor (Tensor): The Tensor to receive. Its data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8 or bool. src (int): The source rank id. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True. @@ -2274,16 +2274,17 @@ def recv(tensor, src=0, group=None, use_calc_stream=True): # required: distributed import paddle - from paddle.distributed import init_parallel_env + import paddle.distributed as dist - init_parallel_env() - if paddle.distributed.ParallelEnv().rank == 0: + dist.init_parallel_env() + if dist.get_rank() == 0: data = paddle.to_tensor([7, 8, 9]) - paddle.distributed.send(data, dst=1) + dist.send(data, dst=1) else: - data = paddle.to_tensor([1,2,3]) - paddle.distributed.recv(data, src=0) - out = data.numpy() + data = paddle.to_tensor([1, 2, 3]) + dist.recv(data, src=0) + print(data) + # [7, 8, 9] (2 GPUs) """ if group is not None and not group.is_member(): return @@ -2291,6 +2292,8 @@ def recv(tensor, src=0, group=None, use_calc_stream=True): src = _get_group_rank(src, group) if in_dygraph_mode(): group = _get_default_group() if group is None else group + backend = _group_map_backend[group] + assert backend != 'gloo', ("backend gloo is not supported yet") task = group.process_group.recv(tensor, src) if use_calc_stream: task.wait() @@ -2340,7 +2343,7 @@ def isend(tensor, dst, group=None): Args: tensor (Tensor): The Tensor to send. Its data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8 or bool. dst (int): The destination rank. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. 
@@ -2358,21 +2361,15 @@ def isend(tensor, dst, group=None): import paddle.distributed as dist dist.init_parallel_env() - rank = dist.get_rank() - world_size = dist.get_world_size() - - if rank == 0: + if dist.get_rank() == 0: data = paddle.to_tensor([7, 8, 9]) - task = paddle.distributed.isend(data, dst=1) + task = dist.isend(data, dst=1) else: data = paddle.to_tensor([1, 2, 3]) - task = paddle.distributed.irecv(data, src=0) - + task = dist.irecv(data, src=0) task.wait() - print(data) - # paddle.tensor([7, 8, 9]) # Rank-0 - # paddle.tensor([7, 8, 9]) # Rank-1 + # [7, 8, 9] (2 GPUs) """ _check_single_tensor(tensor, "tensor") @@ -2381,6 +2378,8 @@ def isend(tensor, dst, group=None): if in_dygraph_mode(): group = _get_default_group() if group is None else group + backend = _group_map_backend[group] + assert backend != 'gloo', ("backend gloo is not supported yet") group_dst_rank = group.get_group_rank(dst) assert group_dst_rank >= 0, ("dst rank out of group, need global rank") return group.process_group.send(tensor, group_dst_rank) @@ -2394,12 +2393,12 @@ def irecv(tensor, src=None, group=None): Args: tensor (Tensor): The Tensor to receive. Its data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8 or bool. src (int): The source rank id. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. Returns: - A distributed task object. + A distributed task object. Warning: This API only supports the dygraph mode. @@ -2412,21 +2411,15 @@ def irecv(tensor, src=None, group=None): import paddle.distributed as dist dist.init_parallel_env() - rank = dist.get_rank() - world_size = dist.get_world_size() - - if rank == 0: + if dist.get_rank() == 0: data = paddle.to_tensor([7, 8, 9]) - task = paddle.distributed.isend(data, dst=1) + task = dist.isend(data, dst=1) else: data = paddle.to_tensor([1, 2, 3]) - task = paddle.distributed.irecv(data, src=0) - + task = dist.irecv(data, src=0) task.wait() - print(data) - # paddle.tensor([7, 8, 9]) # Rank-0 - # paddle.tensor([7, 8, 9]) # Rank-1 + # [7, 8, 9] (2 GPUs) """ _check_single_tensor(tensor, "tensor") if group is not None and not group.is_member(): @@ -2434,6 +2427,8 @@ def irecv(tensor, src=None, group=None): if in_dygraph_mode(): group = _get_default_group() if group is None else group + backend = _group_map_backend[group] + assert backend != 'gloo', ("backend gloo is not supported yet") group_src_rank = group.get_group_rank(src) assert group_src_rank >= 0, ("src rank out of group, need global rank") return group.process_group.recv(tensor, group_src_rank) @@ -2581,8 +2576,9 @@ def reduce_scatter(tensor, Reduces, then scatters a list of tensors to all processes in a group Args: - tensor (Tensor): Output tensor. - tensor_list (list[Tensor]): List of tensors to reduce and scatter. + tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool. + tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32, int64, int8, uint8 or bool. op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. 
@@ -2604,24 +2600,16 @@ def reduce_scatter(tensor, import paddle.distributed as dist dist.init_parallel_env() - rank = dist.get_rank() - world_size = dist.get_world_size() - - if rank == 0: - t1 = paddle.to_tensor([0, 1]) - t2 = paddle.to_tensor([2, 3]) + if dist.get_rank() == 0: + data1 = paddle.to_tensor([0, 1]) + data2 = paddle.to_tensor([2, 3]) else: - t1 = paddle.to_tensor([4, 5]) - t2 = paddle.to_tensor([6, 7]) - - tensor_list = [t1, t2] - - output = paddle.empty(shape=[2], dtype=tensor_list[0].dtype) - dist.reduce_scatter(output, tensor_list) - - print(output) - # [4, 6] # Rank-0 - # [8, 10] # Rank-1 + data1 = paddle.to_tensor([4, 5]) + data2 = paddle.to_tensor([6, 7]) + dist.reduce_scatter(data1, [data1, data2]) + print(data1) + # [4, 6] (2 GPUs, out for rank 0) + # [8, 10] (2 GPUs, out for rank 1) """ _check_single_tensor(tensor, "tensor") @@ -2633,6 +2621,8 @@ def reduce_scatter(tensor, if in_dygraph_mode(): op_type = _get_reduce_op(op, "reduce_scatter") group = _get_default_group() if group is None else group + backend = _group_map_backend[group] + assert backend != 'gloo', ("backend gloo is not supported yet") temp = paddle.concat(tensor_list, axis=0) task = group.process_group._reduce_scatter_base(tensor, temp, op_type) @@ -2654,8 +2644,9 @@ def _reduce_scatter_base(output, Reduces, then scatters a flattened tensor to all processes in a group. Args: - output (Tensor): Output tensor. - input (Tensor): Input tensor that is of size output tensor size times world size + output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool. + input (Tensor): Input tensor that is of size output tensor size times world size. Its data type + should be float16, float32, float64, int32, int64, int8, uint8 or bool. op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM. group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. @@ -2669,23 +2660,19 @@ def _reduce_scatter_base(output, .. 
code-block:: python # required: distributed - import paddle import paddle.distributed as dist dist.init_parallel_env() rank = dist.get_rank() - world_size = dist.get_world_size() - - input = paddle.arange(4) + rank - # [0, 1, 2, 3] # Rank-0 - # [1, 2, 3, 4] # Rank-1 - - output = paddle.empty(shape=[2], dtype=input.dtype) - paddle.distributed.collective._reduce_scatter_base(output, input) + data = paddle.arange(4) + rank + # [0, 1, 2, 3] (2 GPUs, for rank 0) + # [1, 2, 3, 4] (2 GPUs, for rank 1) + output = paddle.empty(shape=[2], dtype=data.dtype) + dist.collective._reduce_scatter_base(output, data) print(output) - # [1, 3] # Rank-0 - # [5, 7] # Rank-1 + # [1, 3] (2 GPUs, out for rank 0) + # [5, 7] (2 GPUs, out for rank 1) """ _check_single_tensor(output, "output") diff --git a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt index 3a30617ede88512e66eb37b7e629b1d80eac04bc..4431f16d7b6e527e3b7fc61ea63e9b53f497a8bd 100644 --- a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt @@ -78,7 +78,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_alltoall_api MODULES test_collective_alltoall_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_alltoall_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) bash_test_modules( @@ -92,6 +92,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) ) set_tests_properties(test_collective_alltoall_single PROPERTIES TIMEOUT "350") endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_collective_alltoall_single_api MODULES + test_collective_alltoall_single_api ENVS + "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_collective_alltoall_single_api + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") +endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( test_collective_barrier_api MODULES test_collective_barrier_api ENVS @@ -117,7 +125,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_broadcast_api MODULES test_collective_broadcast_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_broadcast_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -141,6 +149,13 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT "200" LABELS "RUN_TYPE=DIST") endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_collective_isend_irecv_api MODULES test_collective_isend_irecv_api + ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_collective_isend_irecv_api + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") +endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( test_collective_optimizer MODULES test_collective_optimizer ENVS @@ -186,6 +201,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) ) set_tests_properties(test_collective_reduce_scatter PROPERTIES TIMEOUT "350") endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_collective_reduce_scatter_api MODULES + test_collective_reduce_scatter_api ENVS + 
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_collective_reduce_scatter_api + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") +endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( test_collective_scatter MODULES test_collective_scatter ENVS @@ -212,7 +235,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_sendrecv_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( diff --git a/python/paddle/fluid/tests/unittests/collective/collective_alltoall_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_alltoall_api_dygraph.py index b5994db5cb6c5170e08a6abce1c154f92abf288c..fcabaffd614d031fab8f61f533df7063a9cc2d78 100644 --- a/python/paddle/fluid/tests/unittests/collective/collective_alltoall_api_dygraph.py +++ b/python/paddle/fluid/tests/unittests/collective/collective_alltoall_api_dygraph.py @@ -45,12 +45,9 @@ class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase): with fluid.program_guard(main_prog, startup_program): tindata = paddle.to_tensor(indata) tindata = paddle.split(tindata, 2, axis=0) - tout_data = [] - paddle.distributed.alltoall(tindata, tout_data) - output_data = [] - for data in tout_data: - output_data.append(data.numpy()) - return output_data + toutdata = [] + paddle.distributed.alltoall(tindata, toutdata) + return [data.numpy() for data in toutdata] if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/collective/collective_alltoall_single_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_alltoall_single_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..5fac73989a6060816d5a790041d869d767748d41 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_alltoall_single_api_dygraph.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import test_collective_api_base as test_base + + +class TestCollectiveAllToAllSingleAPI(test_base.TestCollectiveAPIRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.to_tensor(indata) + toutdata = paddle.to_tensor(indata) + paddle.distributed.alltoall_single(tindata, toutdata) + return [toutdata.numpy()] + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveAllToAllSingleAPI, "alltoall") diff --git a/python/paddle/fluid/tests/unittests/collective/collective_broadcast_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_broadcast_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..29f0b74bb405b8ab8908d1c828690a7466ac5be9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_broadcast_api_dygraph.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import unittest +import test_collective_api_base as test_base + + +class TestCollectiveBroadcastAPI(test_base.TestCollectiveAPIRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.to_tensor(indata) + paddle.distributed.broadcast(tindata, src=1) + return [tindata.numpy()] + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveBroadcastAPI, "broadcast") diff --git a/python/paddle/fluid/tests/unittests/collective/collective_isend_irecv_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_isend_irecv_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..70437216a8f8565a9b7099fb1501b26a6a8339ce --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_isend_irecv_api_dygraph.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import unittest +import test_collective_api_base as test_base + + +class TestCollectiveIsendIrecvAPI(test_base.TestCollectiveAPIRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.to_tensor(indata) + if rank == 0: + task = paddle.distributed.isend(tindata, dst=1) + else: + task = paddle.distributed.irecv(tindata, src=0) + task.wait() + return [tindata.numpy()] + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveIsendIrecvAPI, "sendrecv") diff --git a/python/paddle/fluid/tests/unittests/collective/collective_reduce_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_reduce_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..257fc27ceee9f2a48c724749b1fcb24dd3b3f3a4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_reduce_api_dygraph.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import unittest +import test_collective_api_base as test_base + + +class TestCollectiveReduceAPI(test_base.TestCollectiveAPIRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.to_tensor(indata) + paddle.distributed.reduce(tindata, dst=0) + return [tindata.numpy()] + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveReduceAPI, "reduce") diff --git a/python/paddle/fluid/tests/unittests/collective/collective_reduce_scatter_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_reduce_scatter_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0eb6aef9d47a952daa13e400a2aa697bbc6cea --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_reduce_scatter_api_dygraph.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import unittest +import test_collective_api_base as test_base + + +class TestCollectiveReduceScatterAPI(test_base.TestCollectiveAPIRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.to_tensor(indata) + subdata1, subdata2 = paddle.split(tindata, 2, axis=0) + paddle.distributed.reduce_scatter(subdata1, [subdata1, subdata2]) + return [subdata1.numpy()] + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveReduceScatterAPI, "reduce_scatter") diff --git a/python/paddle/fluid/tests/unittests/collective/collective_scatter_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_scatter_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..f37f5653806ec8d30c923a6ea23d6d4e8cbd850e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_scatter_api_dygraph.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import unittest +import test_collective_api_base as test_base + + +class TestCollectiveScatterAPI(test_base.TestCollectiveAPIRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.to_tensor(indata) + subdata1, subdata2 = paddle.split(tindata, 2, axis=0) + if rank == 0: + paddle.distributed.scatter(subdata1, src=1) + else: + paddle.distributed.scatter(subdata1, + tensor_list=[subdata1, subdata2], + src=1) + return [subdata1.numpy()] + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveScatterAPI, "scatter") diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_api.py index 2fe1252846cb326430316026161f2d1f944a4d95..e079e99efebf5735f5f114f660b1313eac49dac9 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_api.py @@ -31,10 +31,16 @@ class TestCollectiveAllToAllAPI(TestDistBase): self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl") def test_alltoall_nccl_dygraph(self): - self.check_with_place("collective_alltoall_api_dygraph.py", - "alltoall", - "nccl", - static_mode="0") + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_alltoall_api_dygraph.py", + "alltoall", + "nccl", + static_mode="0", + dtype=dtype) if __name__ == '__main__': diff --git 
a/python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_single_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_single_api.py new file mode 100644 index 0000000000000000000000000000000000000000..fb1e5e9da22ef10e18358fa45a64c6ddd33d3bce --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_single_api.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import test_collective_api_base as test_base + + +class TestCollectiveAllToAllSingleAPI(test_base.TestDistBase): + + def _setup_config(self): + pass + + def test_alltooall_single_nccl_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_alltoall_single_api_dygraph.py", + "alltoall", + "nccl", + static_mode="0", + dtype=dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py index 289cb7152ac3663f9784d8ddcb55f5f295a74a89..2d21be144a68b63a13a7bf7236c1921b74af7714 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py @@ -35,6 +35,31 @@ class TestCollectiveBroadcastAPI(TestDistBase): self.check_with_place("collective_broadcast_api.py", "broadcast", "gloo", "0") + def test_broadcast_nccl_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_broadcast_api_dygraph.py", + "broadcast", + "nccl", + static_mode="0", + dtype=dtype) + + def test_broadcast_gloo_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_broadcast_api_dygraph.py", + "broadcast", + "gloo", + "0", + static_mode="0", + dtype=dtype) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_isend_irecv_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_isend_irecv_api.py new file mode 100644 index 0000000000000000000000000000000000000000..f9613abc2406363b3c6525da2af22c6583bbbdd1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_isend_irecv_api.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import test_collective_api_base as test_base + + +class TestCollectiveIsendIrecvAPI(test_base.TestDistBase): + + def _setup_config(self): + pass + + def test_isend_irecv_nccl_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_isend_irecv_api_dygraph.py", + "sendrecv", + "nccl", + static_mode="0", + dtype=dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py index 2da70f5a94dfd488a27beb582e208432f1a58bab..2fa84ea2ed7f187eca3b0eb52e6e80fd2652d0ce 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py @@ -38,6 +38,31 @@ class TestCollectiveReduceAPI(TestDistBase): def test_reduce_gloo(self): self.check_with_place("collective_reduce_api.py", "reduce", "gloo", "1") + def test_reduce_nccl_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_reduce_api_dygraph.py", + "reduce", + "nccl", + static_mode="0", + dtype=dtype) + + def test_reduce_gloo_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_reduce_api_dygraph.py", + "reduce", + "gloo", + "1", + static_mode="0", + dtype=dtype) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_scatter_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_scatter_api.py new file mode 100644 index 0000000000000000000000000000000000000000..1d25527407f4533fe3f7225e436479695c214c95 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_scatter_api.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import paddle +import test_collective_api_base as test_base + + +class TestCollectiveReduceScatterAPI(test_base.TestDistBase): + + def _setup_config(self): + pass + + def test_reduce_scatter_nccl_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_reduce_scatter_api_dygraph.py", + "reduce_scatter", + "nccl", + static_mode="0", + dtype=dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_scatter_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_scatter_api.py index 18c720c5628149dfef6e964ebc7fdc9225a5df3d..4093b8ed69093ebd5fccc53a732dd727e4d5a6e8 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_scatter_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_scatter_api.py @@ -34,6 +34,31 @@ class TestCollectiveScatterAPI(TestDistBase): def test_scatter_nccl(self): self.check_with_place("collective_scatter_api.py", "scatter", "nccl") + def test_scatter_nccl_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_scatter_api_dygraph.py", + "scatter", + "nccl", + static_mode="0", + dtype=dtype) + + def test_scatter_gloo_dygraph(self): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: + self.check_with_place("collective_scatter_api_dygraph.py", + "scatter", + "gloo", + "4", + static_mode="0", + dtype=dtype) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_sendrecv_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_sendrecv_api.py index c0a14f7e2860cf1db7feef6875c2f46837602512..940d6ec709bf1fb9ab704dcb869dc8b9c9e898c3 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_sendrecv_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_sendrecv_api.py @@ -33,11 +33,16 @@ class TestCollectiveSendRecvAPI(TestDistBase): # "nccl") def test_sendrecv_nccl_dygraph(self): - if paddle.fluid.core.is_compiled_with_cuda(): + dtypes_to_test = [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8', + 'bool' + ] + for dtype in dtypes_to_test: self.check_with_place("collective_sendrecv_api_dygraph.py", "sendrecv", "nccl", - static_mode='0') + static_mode="0", + dtype=dtype) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/collective/testslist.csv b/python/paddle/fluid/tests/unittests/collective/testslist.csv index bc341433b32f8ecdfdd0891aa2c815146c473556..fc08f861e907749139a3a3675a41c5d8618b8e84 100644 --- a/python/paddle/fluid/tests/unittests/collective/testslist.csv +++ b/python/paddle/fluid/tests/unittests/collective/testslist.csv @@ -8,23 +8,26 @@ test_collective_split_embedding,linux,rocm;gpu,300,DIST,../dist_test.sh,2,,PYTHO test_collective_allgather_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_allgather_object_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_allreduce_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., 
-test_collective_alltoall_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_alltoall_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_alltoall_single,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_alltoall_single_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., -test_collective_broadcast_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_broadcast_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_global_scatter,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_isend_irecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_optimizer,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_process_group,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_reduce,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_reduce_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_reduce_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., -test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_sendrecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index 4131239adf792153418fe11de9f3ba5890134325..21c9b172e9822590c8d8c2d9c73c5e2b7a2df896 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -335,6 +335,12 @@ class TestDistBase(unittest.TestCase): need_result2 = need_result[need_result.shape[0] // 2:] np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05) np.testing.assert_allclose(tr1_out[0], 
need_result2, rtol=1e-05) + elif col_type == "reduce_scatter": + need_result = input1 + input2 + need_result1 = need_result[0:need_result.shape[0] // 2] + need_result2 = need_result[need_result.shape[0] // 2:] + np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05) + np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05) elif col_type == "allreduce": need_result = input1 + input2 np.testing.assert_allclose(tr0_out[0], diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 42e3bc9039f08a75c75fc0f3bc5439f21921a76d..5e05a93e905963f93cfdf1c2e696774a108242ad 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1015,7 +1015,7 @@ def concat(x, axis=0, name=None): Args: x (list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16, - float32, float64, int32, int64, uint8. All the Tensors in ``x`` must have same data type. + float32, float64, int32, int64, int8, uint8. All the Tensors in ``x`` must have same data type. axis (int|Tensor, optional): Specify the axis to operate on the input Tensors. It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, @@ -1073,10 +1073,10 @@ check_type(input, 'input', (list, tuple, Variable), 'concat') if not isinstance(input, Variable): for id, x in enumerate(input): - check_variable_and_dtype( - x, 'input[' + str(id) + ']', - ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], - 'concat') + check_variable_and_dtype(x, 'input[' + str(id) + ']', [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', + 'int8', 'uint8' + ], 'concat') if x.dtype != input[0].dtype: raise TypeError( "All the Tensors in the input must have the same data type."
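Taken together, the patch enables int8, uint8 and bool tensors in the dygraph collective APIs (broadcast, reduce, scatter, alltoall, alltoall_single, send/recv, isend/irecv and reduce_scatter) and registers the matching concat kernel dtypes. The sketch below mirrors the new dygraph tests to show how the added dtypes can be exercised end to end. It is illustrative only and not part of the patch: it assumes a 2-GPU NCCL build, and the file name demo.py is hypothetical (launch with: python -m paddle.distributed.launch --gpus 0,1 demo.py).

    # demo.py - illustrative sketch only; assumes 2 GPUs and an NCCL-enabled build
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    rank = dist.get_rank()

    # bool broadcast: after the call every rank holds rank 0's mask
    if rank == 0:
        mask = paddle.to_tensor([True, False, True])
    else:
        mask = paddle.to_tensor([False, False, False])
    dist.broadcast(mask, src=0)

    # uint8 reduce_scatter: each rank keeps one summed shard
    part1 = paddle.to_tensor([1, 2], dtype='uint8')
    part2 = paddle.to_tensor([3, 4], dtype='uint8')
    dist.reduce_scatter(part1, [part1, part2])

    print(rank, mask.numpy(), part1.numpy())
    # rank 0: [ True False  True] [2 4]
    # rank 1: [ True False  True] [6 8]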