diff --git a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc index 91aa9e63ddf707e76070e633654112e75367ef99..b23942b114f3be59af1e544344ce45ab73a2fdf0 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupGloo.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupGloo.cc @@ -79,6 +79,15 @@ namespace distributed { case experimental::DataType::INT64: \ func(args); \ break; \ + case experimental::DataType::INT8: \ + func(args); \ + break; \ + case experimental::DataType::UINT8: \ + func(args); \ + break; \ + case experimental::DataType::BOOL: \ + func(args); \ + break; \ default: \ VLOG(0) << "Error: Unknown DataType."; \ exit(-1); \ diff --git a/paddle/fluid/operators/collective/c_allgather_op.cc b/paddle/fluid/operators/collective/c_allgather_op.cc index 3e6aaffe5738c9f8492d8c0046883a349fa1be66..bf458956703f8674e7631a351c630dc4040ad8d0 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.cc +++ b/paddle/fluid/operators/collective/c_allgather_op.cc @@ -94,4 +94,7 @@ REGISTER_OP_CPU_KERNEL(c_allgather, ops::CAllGatherOpCPUKernel, ops::CAllGatherOpCPUKernel, ops::CAllGatherOpCPUKernel, + ops::CAllGatherOpCPUKernel, + ops::CAllGatherOpCPUKernel, + ops::CAllGatherOpCPUKernel, ops::CAllGatherOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc index f82e518d0d7eda88edcc98411a70071581d01862..e9228a28dbac05131dd1ad0138a0b556e1a2f3dc 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc +++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc @@ -100,5 +100,8 @@ REGISTER_OP_CUDA_KERNEL(c_allgather, ops::CAllGatherOpCUDAKernel, #endif ops::CAllGatherOpCUDAKernel, + ops::CAllGatherOpCUDAKernel, + ops::CAllGatherOpCUDAKernel, ops::CAllGatherOpCUDAKernel, + ops::CAllGatherOpCUDAKernel, ops::CAllGatherOpCUDAKernel); diff --git a/paddle/fluid/platform/device/gpu/nccl_helper.h b/paddle/fluid/platform/device/gpu/nccl_helper.h index e042c239b9aead7e09d8936cb4613c0234a99fe6..1ce8038f0e3e2b6110d37a83fbb3d7bd1f0e191d 100644 --- a/paddle/fluid/platform/device/gpu/nccl_helper.h +++ b/paddle/fluid/platform/device/gpu/nccl_helper.h @@ -55,6 +55,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) { return ncclFloat16; } else if (type == framework::proto::VarType::INT8) { return ncclInt8; + } else if (type == framework::proto::VarType::UINT8) { + return ncclUint8; + } else if (type == framework::proto::VarType::BOOL) { + return ncclUint8; #if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000 } else if (type == framework::proto::VarType::BF16) { return ncclBfloat16; @@ -76,6 +80,12 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) { return ncclInt64; } else if (type == experimental::DataType::FLOAT16) { return ncclFloat16; + } else if (type == experimental::DataType::UINT8) { + return ncclUint8; + } else if (type == experimental::DataType::INT8) { + return ncclInt8; + } else if (type == experimental::DataType::BOOL) { + return ncclUint8; #if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000 } else if (type == experimental::DataType::BFLOAT16) { return ncclBfloat16; diff --git a/paddle/phi/kernels/cpu/split_kernel.cc b/paddle/phi/kernels/cpu/split_kernel.cc index 288cdd235aede1c9023f211ca84fe79d32222349..6034949cd817e9a790ecf71cda7eb75b4c4fbd13 100644 --- a/paddle/phi/kernels/cpu/split_kernel.cc +++ b/paddle/phi/kernels/cpu/split_kernel.cc @@ -72,5 +72,7 @@ 
PD_REGISTER_KERNEL(split, int64_t, int, bool, + uint8_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/split_kernel.cu b/paddle/phi/kernels/gpu/split_kernel.cu index 5f2741e53ddb99f0f749a6d95ce8b1d858b33444..1e855905ae00b56205e2e05b042f1b264173bc1c 100644 --- a/paddle/phi/kernels/gpu/split_kernel.cu +++ b/paddle/phi/kernels/gpu/split_kernel.cu @@ -71,5 +71,7 @@ PD_REGISTER_KERNEL(split, int64_t, int, bool, + uint8_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16) {} diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index ab83e2929e4bc3027084405bf423a79b227a2483..a238126dc6c38e8b4aa162edc183852f8c701d55 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -31,6 +31,7 @@ from .collective import broadcast # noqa: F401 from .collective import all_reduce # noqa: F401 from .collective import reduce # noqa: F401 from .collective import all_gather # noqa: F401 +from .collective import all_gather_object # noqa: F401 from .collective import scatter # noqa: F401 from .collective import barrier # noqa: F401 from .collective import ReduceOp # noqa: F401 @@ -71,7 +72,8 @@ __all__ = [ # noqa "init_parallel_env", "gloo_init_parallel_env", "gloo_barrier", "gloo_release", "QueueDataset", "split", "CountFilterEntry", "ShowClickEntry", "get_world_size", "get_group", "all_gather", - "InMemoryDataset", "barrier", "all_reduce", "alltoall", "send", "reduce", - "recv", "ReduceOp", "wait", "get_rank", "ProbabilityEntry", "ParallelMode", - "is_initialized", "isend", "irecv", "reduce_scatter" + "all_gather_object", "InMemoryDataset", "barrier", "all_reduce", "alltoall", + "send", "reduce", "recv", "ReduceOp", "wait", "get_rank", + "ProbabilityEntry", "ParallelMode", "is_initialized", "isend", "irecv", + "reduce_scatter" ] diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 62b18298f11e05b16f0d09ab645592f307fc0b0b..e27ae8ef6c4a10a5bd5baf8d53c976482442eb90 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -14,6 +14,8 @@ import numpy as np import os +import pickle +import io from datetime import timedelta from ..fluid.layer_helper import LayerHelper from ..fluid.framework import Variable @@ -927,9 +929,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): Args: tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128. tensor (Tensor): The Tensor to send. Its data type - should be float16, float32, float64, int32 or int64. + should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128. group (Group): The group instance return by new_group or None for global default group. use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False). Default to True. @@ -941,7 +943,6 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): .. 
code-block:: python # required: distributed - import numpy as np import paddle from paddle.distributed import init_parallel_env @@ -949,21 +950,26 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): init_parallel_env() tensor_list = [] if paddle.distributed.ParallelEnv().local_rank == 0: - np_data1 = np.array([[4, 5, 6], [4, 5, 6]]) - np_data2 = np.array([[4, 5, 6], [4, 5, 6]]) - data1 = paddle.to_tensor(np_data1) - data2 = paddle.to_tensor(np_data2) + data1 = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) paddle.distributed.all_gather(tensor_list, data1) else: - np_data1 = np.array([[1, 2, 3], [1, 2, 3]]) - np_data2 = np.array([[1, 2, 3], [1, 2, 3]]) - data1 = paddle.to_tensor(np_data1) - data2 = paddle.to_tensor(np_data2) + data2 = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) paddle.distributed.all_gather(tensor_list, data2) """ if group is not None and not group.is_member(): return + def convert_to_complex(list_of_tensor): + list_of_complex = [] + for tensor in list_of_tensor: + list_of_complex.append(paddle.as_complex(tensor)) + return list_of_complex + + is_input_complex = (tensor.dtype == paddle.complex64 + or tensor.dtype == paddle.complex128) + if is_input_complex: + tensor = paddle.as_real(tensor) + if in_dygraph_mode(): group = _get_default_group() if group is None else group if len(tensor_list) == 0: @@ -975,7 +981,11 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): task = group.process_group.all_gather(tensor, out) task.wait() tensor_list.clear() - tensor_list.extend(paddle.split(out, group.nranks, 0)) + list_of_tensor = paddle.split(out, group.nranks, 0) + if is_input_complex: + tensor_list.extend(convert_to_complex(list_of_tensor)) + else: + tensor_list.extend(list_of_tensor) return ring_id = 0 if group is None else group.id @@ -992,13 +1002,14 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): raise ValueError("The type of 'tensor_list' for all_gather " "should be list.") for elem in tensor_list: - check_variable_and_dtype( - elem, 'tensor_list', - ['float16', 'float32', 'float64', 'int32', 'int64'], - 'all_gather') - check_variable_and_dtype( - tensor, 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather') + check_variable_and_dtype(elem, 'tensor_list', [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'bool', + 'int8', 'uint8', 'complex64', 'complex128' + ], 'all_gather') + check_variable_and_dtype(tensor, 'tensor', [ + 'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8', + 'uint8', 'complex64', 'complex128' + ], 'all_gather') helper.append_op(type=op_type, inputs={'X': [tensor]}, outputs={'Out': [out]}, @@ -1008,7 +1019,69 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): 'nranks': nranks }) - tensor_list.extend(paddle.split(out, nranks, 0)) + list_of_tensor = paddle.split(out, nranks, 0) + if is_input_complex: + tensor_list.extend(convert_to_complex(list_of_tensor)) + else: + tensor_list.extend(list_of_tensor) + + +def _convert_object_to_tensor(obj): + _pickler = pickle.Pickler + f = io.BytesIO() + _pickler(f).dump(obj) + data = np.frombuffer(f.getvalue(), dtype=np.uint8) + tensor = paddle.to_tensor(data) + return tensor + + +def _convert_tensor_to_object(tensor): + _unpickler = pickle.Unpickler + return _unpickler(io.BytesIO(tensor.numpy())).load() + + +def all_gather_object(object_list, obj, group=None): + """ + + Gather picklable objects from all participators and all get the result. 
Similiar to all_gather(), but python object can be passed in. + + Args: + object_list (list): A list of output object. The datatype of every element in the list is same as the input obj. + obj (Any): The picklable object to send. + group (Group): The group instance return by new_group or None for global default group. + + Returns: + None. + + Warning: + This API only supports the dygraph mode. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) + dist.init_parallel_env() + object_list = [] + if paddle.distributed.ParallelEnv().local_rank == 0: + obj = {"foo": [1, 2, 3]} + paddle.distributed.all_gather_object(object_list, obj) + else: + obj = {"bar": [4, 5, 6]} + paddle.distributed.all_gather_object(object_list, obj) + """ + assert in_dygraph_mode( + ), "all_gather_object doesn't support static graph mode." + + tensor = _convert_object_to_tensor(obj) + + tensor_list = [] + all_gather(tensor_list, tensor, group) + for tensor in tensor_list: + object_list.append(_convert_tensor_to_object(tensor)) def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index d8f4ecf6731de071ee2ee4f7a9d750dfb9aa8a8e..e34616f5feb3ad95140cf1ee0fadaa0bfd64fcdf 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -183,6 +183,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) list(REMOVE_ITEM TEST_OPS test_new_group_api) list(REMOVE_ITEM TEST_OPS test_collective_broadcast_api) list(REMOVE_ITEM TEST_OPS test_collective_allgather_api) + list(REMOVE_ITEM TEST_OPS test_collective_allgather_object_api) list(REMOVE_ITEM TEST_OPS test_collective_alltoall_api) list(REMOVE_ITEM TEST_OPS test_collective_global_gather) list(REMOVE_ITEM TEST_OPS test_collective_global_scatter) @@ -1598,7 +1599,9 @@ if(APPLE) endif() if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) - set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120) + set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 300) + set_tests_properties(test_collective_allgather_object_api PROPERTIES TIMEOUT + 120) set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 200) set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 200) @@ -1629,6 +1632,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) test_new_group_api test_collective_broadcast_api test_collective_allgather_api + test_collective_allgather_object_api test_collective_alltoall_api test_collective_global_gather test_collective_global_scatter diff --git a/python/paddle/fluid/tests/unittests/collective_allgather_api.py b/python/paddle/fluid/tests/unittests/collective_allgather_api.py index d2a639d0294db33572d7dc18b3bb692b0cd476bd..b5dff7309ddfe839443fef6cb095af94875d80a7 100644 --- a/python/paddle/fluid/tests/unittests/collective_allgather_api.py +++ b/python/paddle/fluid/tests/unittests/collective_allgather_api.py @@ -30,28 +30,64 @@ import paddle.fluid.profiler as profiler import paddle.fluid.unique_name as nameGen from paddle.fluid import core import unittest +import pickle from multiprocessing import Process import paddle.fluid.layers as layers from functools import reduce -from test_collective_api_base import TestCollectiveAPIRunnerBase, 
runtime_main +import test_collective_api_base as test_base paddle.enable_static() -class TestCollectiveAllgatherAPI(TestCollectiveAPIRunnerBase): +class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase): def __init__(self): self.global_ring_id = 0 - def get_model(self, main_prog, startup_program, rank): + def get_model(self, main_prog, startup_program, rank, dtype=None): + dtype = "float32" if dtype is None else dtype with fluid.program_guard(main_prog, startup_program): tensor_list = [] - tindata = layers.data(name="tindata", - shape=[10, 1000], - dtype='float32') + tindata = layers.data(name="tindata", shape=[10, 1000], dtype=dtype) paddle.distributed.all_gather(tensor_list, tindata) return tensor_list + def run_trainer(self, args): + train_prog = fluid.Program() + startup_prog = fluid.Program() + endpoints = args["endpoints"].split(",") + rank = args["trainerid"] + current_endpoint = args["currentendpoint"] + nranks = 2 + paddle.distributed.init_parallel_env() + if args['backend'] == 'nccl': + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace( + device_id) #if args.use_gpu else fluid.CPUPlace() + elif args['backend'] == 'bkcl': + device_id = int(os.getenv("FLAGS_selected_xpus", "0")) + place = fluid.XPUPlace(device_id) + else: + place = fluid.CPUPlace() + indata = test_base.create_test_data(shape=(10, 1000), + dtype=args["dtype"], + seed=os.getpid()) + assert args[ + 'static_mode'] == 1, "collective_allgather_api only support static mode" + result = self.get_model(train_prog, + startup_prog, + rank, + dtype=args["dtype"]) + exe = fluid.Executor(place) + exe.run(startup_prog) + fetch_list = [] + for elem in result: + fetch_list.append(elem.name) + out = exe.run(train_prog, + feed={'tindata': indata}, + fetch_list=fetch_list) + sys.stdout.buffer.write(pickle.dumps(out)) + if __name__ == "__main__": - runtime_main(TestCollectiveAllgatherAPI, "allgather") + test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather") diff --git a/python/paddle/fluid/tests/unittests/collective_allgather_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective_allgather_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..d485fd23d957106a3da858be17ce34408dda45dc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_allgather_api_dygraph.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import unittest +import test_collective_api_base as test_base + + +class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.to_tensor(indata) + tensor_list = [] + paddle.distributed.all_gather(tensor_list, tindata) + return [tensor.numpy() for tensor in tensor_list] + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather") diff --git a/python/paddle/fluid/tests/unittests/collective_allgather_object_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective_allgather_object_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..b26556a749ae624297fe83146f4079fbc216fbdf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_allgather_object_api_dygraph.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import test_collective_api_base as test_base + + +class TestCollectiveAllgatherObjectAPI(test_base.TestCollectiveAPIRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + object_list = [] + paddle.distributed.all_gather_object(object_list, indata) + return object_list + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveAllgatherObjectAPI, "allgather_object") diff --git a/python/paddle/fluid/tests/unittests/test_collective_allgather_api.py b/python/paddle/fluid/tests/unittests/test_collective_allgather_api.py index ebc52ded8bc7298eee2b779aca47ced2cec817ca..a01a96a0d6b29a4dcb116685bfebccd3819010e8 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_allgather_api.py +++ b/python/paddle/fluid/tests/unittests/test_collective_allgather_api.py @@ -28,12 +28,212 @@ class TestCollectiveAllgatherAPI(TestDistBase): pass def test_allgather_nccl(self): - self.check_with_place("collective_allgather_api.py", "allgather", - "nccl") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="float16") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="float32") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="float64") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="bool") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="uint8") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="int8") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + 
dtype="int32") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="int64") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="complex64") + self.check_with_place("collective_allgather_api.py", + "allgather", + "nccl", + dtype="complex128") def test_allgather_gloo(self): - self.check_with_place("collective_allgather_api.py", "allgather", - "gloo", "3") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="float16") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="float32") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="float64") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="bool") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="uint8") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="int8") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="int32") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="int64") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="complex64") + self.check_with_place("collective_allgather_api.py", + "allgather", + "gloo", + "3", + dtype="complex128") + + def test_allgatther_nccl_dygraph(self): + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="float16") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="float32") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="float64") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="bool") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="uint8") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="int8") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="int32") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="int64") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="complex64") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "nccl", + static_mode="0", + dtype="complex128") + + def test_allgather_gloo_dygraph(self): + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="float16") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="float32") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="float64") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="bool") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="uint8") + self.check_with_place("collective_allgather_api_dygraph.py", + 
"allgather", + "gloo", + "3", + static_mode="0", + dtype="int8") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="int32") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="int64") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="complex64") + self.check_with_place("collective_allgather_api_dygraph.py", + "allgather", + "gloo", + "3", + static_mode="0", + dtype="complex128") if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_collective_allgather_object_api.py b/python/paddle/fluid/tests/unittests/test_collective_allgather_object_api.py new file mode 100644 index 0000000000000000000000000000000000000000..63ac93adf8dee1140f2d66355e1851c9c3e5d14b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_allgather_object_api.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import test_collective_api_base as test_base + + +class TestCollectiveAllgatherObjectAPI(test_base.TestDistBase): + + def _setup_config(self): + pass + + def test_allgather_nccl(self): + self.check_with_place("collective_allgather_object_api_dygraph.py", + "allgather_object", + "nccl", + static_mode="0", + dtype="pylist") + self.check_with_place("collective_allgather_object_api_dygraph.py", + "allgather_object", + "nccl", + static_mode="0", + dtype="pydict") + + def test_allgather_gloo_dygraph(self): + self.check_with_place("collective_allgather_object_api_dygraph.py", + "allgather_object", + "gloo", + "3", + static_mode="0", + dtype="pylist") + self.check_with_place("collective_allgather_object_api_dygraph.py", + "allgather_object", + "gloo", + "3", + static_mode="0", + dtype="pydict") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index 96f2d2bf4d504f911a40e2530b087f198f33de29..79457571aca90f6d0e879c55e9550e639df3e947 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -31,9 +31,77 @@ import paddle.fluid.unique_name as nameGen from paddle.fluid import core +def create_bool_test_data(shape=None, seed=None): + if seed: + np.random.seed(seed) + data = np.random.choice([True, False], size=shape) + return data + + +def create_float_test_data(shape=None, dtype=None, seed=None): + if seed: + np.random.seed(seed) + data = np.random.random(shape).astype(dtype) + return data + + +def create_int_test_data(shape=None, dtype=None, seed=None): + if seed: + np.random.seed(seed) + data = np.random.randint(0, high=100, size=shape).astype(dtype) + return data + + +def create_complex_test_data(shape=None, dtype=None, seed=None): + 
if seed: + np.random.seed(seed) + data = np.random.random(shape).astype(dtype) + data.imag = np.random.random(shape) + return data + + +def create_pylist_test_data(shape=None, seed=None): + if seed: + np.random.seed(seed) + data = np.random.random(shape).tolist() + return data + + +def create_pydict_test_data(shape=None, seed=None): + if seed: + np.random.seed(seed) + key = [i for i in range(0, shape[0])] + value = np.random.random(shape).tolist() + data = dict(zip(key, value)) + return data + + +def create_test_data(shape=None, dtype=None, seed=None): + assert shape, "Shape should be specified" + if dtype == "float32" or dtype == "float16" or dtype == "float64": + return create_float_test_data(shape=shape, dtype=dtype, seed=seed) + elif dtype == "bool": + return create_bool_test_data(shape=shape, seed=seed) + elif dtype == "int32" or dtype == "int64" or dtype == "int8" or dtype == "uint8": + return create_int_test_data(shape=shape, dtype=dtype, seed=seed) + elif dtype == "complex64" or dtype == "complex128": + return create_complex_test_data(shape=shape, dtype=dtype, seed=seed) + elif dtype == "pylist": + return create_pylist_test_data(shape=shape, seed=seed) + elif dtype == "pydict": + return create_pydict_test_data(shape=shape, seed=seed) + else: + raise NotImplementedError("Unsupported dtype for creating test data.") + + class TestCollectiveAPIRunnerBase(object): - def get_model(self, train_prog, startup_prog, rank, indata=None): + def get_model(self, + train_prog, + startup_prog, + rank, + indata=None, + dtype=None): raise NotImplementedError( "get model should be implemented by child class.") @@ -54,8 +122,9 @@ class TestCollectiveAPIRunnerBase(object): place = fluid.XPUPlace(device_id) else: place = fluid.CPUPlace() - np.random.seed(os.getpid()) - indata = np.random.random((10, 1000)).astype("float32") + indata = create_test_data(shape=(10, 1000), + dtype=args["dtype"], + seed=os.getpid()) if args['static_mode']: result = self.get_model(train_prog, startup_prog, rank) exe = fluid.Executor(place) @@ -83,6 +152,7 @@ def runtime_main(test_class, col_type): args["backend"] = os.getenv("BACKEND") args["path_id"] = int(os.getenv("PATH_ID")) args["static_mode"] = int(os.getenv("STATIC_MODE")) + args["dtype"] = os.getenv("DTYPE") model.run_trainer(args) @@ -203,18 +273,22 @@ class TestDistBase(unittest.TestCase): static_mode="1", check_error_log=False, need_envs={}, - eager_mode=True): + eager_mode=True, + dtype=None): if backend == "nccl" or backend == "bkcl": with_gloo = '0' else: with_gloo = '1' required_envs = os.environ.copy() + dtype = "float32" if dtype is None else dtype additional_envs = { "NCCL_P2P_DISABLE": "1", "STATIC_MODE": static_mode, "PADDLE_WITH_GLOO": with_gloo, + "PADDLE_DISTRI_BACKEND": backend, "BACKEND": backend, - "PATH_ID": path_id + "PATH_ID": path_id, + "DTYPE": dtype } required_envs.update(additional_envs) required_envs.update(need_envs) @@ -234,16 +308,18 @@ class TestDistBase(unittest.TestCase): tr0_out, tr1_out, pid0, pid1 = self._run_cluster( model_file, required_envs) - np.random.seed(pid0) - input1 = np.random.random((10, 1000)) - np.random.seed(pid1) - input2 = np.random.random((10, 1000)) + input1 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid0) + input2 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid1) if col_type == "allgather": need_result = np.vstack((input1, input2)) tr_out0 = np.vstack((tr0_out[0], tr0_out[1])) tr_out1 = np.vstack((tr1_out[0], tr1_out[1])) self.assertTrue(np.allclose(tr_out0, need_result)) 
self.assertTrue(np.allclose(tr_out1, need_result)) + if col_type == "allgather_object": + need_result = [input1, input2] + self.assertEqual(need_result, tr0_out) + self.assertEqual(need_result, tr1_out) elif col_type == "broadcast": need_result = input2 self.assertTrue(np.allclose(tr0_out, need_result)) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 5b4f61ce872f53e5117e92f5aa3d8233e91dddb5..c170a6fb04e88d285ea4c0ae19f3c37d2c3e533a 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1737,7 +1737,7 @@ def split(x, num_or_sections, axis=0, name=None): Split the input tensor into multiple sub-Tensors. Args: - x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64. + x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64. num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. If ``num_or_sections`` is a list or tuple, the length of it indicates the number of @@ -1814,9 +1814,10 @@ def split(x, num_or_sections, axis=0, name=None): _C_ops.split(input, out, *attrs) return out - check_variable_and_dtype( - input, 'input', - ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'split') + check_variable_and_dtype(input, 'input', [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8', + 'int8' + ], 'split') check_type(num_or_sections, 'num_or_sections', (list, int, tuple), 'split') check_type(dim, 'dim', (int, Variable), 'split') if isinstance(dim, Variable):
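
For reviewers, a minimal usage sketch of the behaviour this patch enables follows. It is not part of the diff: the launch command, device string, and sample values are illustrative assumptions, while the API calls themselves (paddle.distributed.all_gather with the newly registered dtypes and the new paddle.distributed.all_gather_object) follow the docstring examples added above.

    # Assumed launch: python -m paddle.distributed.launch --gpus 0,1 demo.py
    # (any 2-rank setup works; this mirrors the docstring examples in the patch)
    import paddle
    import paddle.distributed as dist

    # Device selection follows the all_gather_object docstring added by this patch.
    paddle.set_device('gpu:%d' % paddle.distributed.ParallelEnv().dev_id)
    dist.init_parallel_env()
    rank = paddle.distributed.ParallelEnv().local_rank

    # all_gather now accepts bool/int8/uint8/complex64/complex128 tensors;
    # complex inputs are viewed as real via paddle.as_real before the collective
    # and restored with paddle.as_complex afterwards (see collective.py above).
    tensor_list = []
    if rank == 0:
        data = paddle.to_tensor([1 + 2j, 3 + 4j], dtype='complex64')
    else:
        data = paddle.to_tensor([5 + 6j, 7 + 8j], dtype='complex64')
    dist.all_gather(tensor_list, data)
    # tensor_list holds one complex64 tensor per rank, in rank order.

    # all_gather_object pickles any picklable object into a uint8 tensor,
    # all-gathers the buffers, and unpickles them on every rank (dygraph only).
    object_list = []
    obj = {"foo": [1, 2, 3]} if rank == 0 else {"bar": [4, 5, 6]}
    dist.all_gather_object(object_list, obj)
    # object_list == [{"foo": [1, 2, 3]}, {"bar": [4, 5, 6]}] on every rank.

Note that all_gather_object relies on the uint8 all_gather path added in this patch (pickle bytes are wrapped with np.frombuffer into a uint8 tensor), which is why the kernel and NCCL/Gloo dtype registrations above are a prerequisite for the Python-level API.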