diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h index ad8ba19f8bae17152eb7082f86d308f20529dfb3..de67eaf2a5e875a598cb68ce004874e697542e4d 100644 --- a/paddle/fluid/distributed/collective/process_group.h +++ b/paddle/fluid/distributed/collective/process_group.h @@ -332,6 +332,30 @@ class ProcessGroup { GetBackendName())); } + virtual std::shared_ptr<ProcessGroup::Task> Gather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support gather " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + + virtual std::shared_ptr<ProcessGroup::Task> Gather( + std::vector<phi::DenseTensor>* gather_tensors_ptr, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW( + phi::errors::Unimplemented("ProcessGroup%s does not support gather " + "with sync_op and use_calc_stream flag.", + GetBackendName())); + } + virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor, int src_rank, bool sync_op, diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index c516b1cd31df2001b7fe42e0fa7b0db7bdb361c1..4653799401bbe73a2a9bd464b0ada2cfeb08b810 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -475,6 +475,71 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter( use_calc_stream); } +std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) { + std::vector<phi::DenseTensor> partial_tensors; + if (rank_ == opts.root_rank) { + partial_tensors.reserve(size_); + size_t offset = 0; + size_t numel = out_tensor->numel() / size_; + for (auto i = 0; i < size_; i++) { + partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel)); + offset += numel; + } + } + return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream); +} + +std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather( + std::vector<phi::DenseTensor>* gather_tensors_ptr, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) { + auto& gather_tensors = *gather_tensors_ptr; + PADDLE_ENFORCE_GT(size_, + opts.root_rank, + phi::errors::InvalidArgument( + "root world size [%d] is less than root rank [%d]", + size_, + opts.root_rank)); + auto gather_func = [&](ncclComm_t comm, gpuStream_t stream) { + // shape check + if (FLAGS_enable_nccl_dynamic_check) { + phi::distributed::NCCLDynamicCheck::CheckGatherShape( + in_tensor, gather_tensors, opts.root_rank, rank_, size_, comm); + } + GroupStart(); + // root receives from all ranks + if (rank_ == opts.root_rank) { + for (auto i = 0; i < size_; i++) { + auto& gather_tensor = gather_tensors[i]; + NCCL_CHECK( + phi::dynload::ncclRecv(gather_tensor.data(), + gather_tensor.numel(), + phi::ToNCCLDataType(gather_tensor.dtype()), + i, + comm, + stream)); + } + } + // send to root + NCCL_CHECK(phi::dynload::ncclSend(in_tensor.data(), + in_tensor.numel(), + phi::ToNCCLDataType(in_tensor.dtype()), + opts.root_rank, + comm, + stream)); + GroupEnd(); + }; + return RunFnInNCCLEnv( + gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream); +} + std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv( phi::DenseTensor* tensor, int src_rank, diff --git
a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index c5d3842a096d609ae1d04af685eb08e6284f0733..d4a159a7f4550d196b5d80805575e04eca83ec4f 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -136,6 +136,19 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { bool sync_op, bool use_calc_stream) override; + std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr<ProcessGroup::Task> Gather( + std::vector<phi::DenseTensor>* gather_tensors_ptr, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor, int src_rank, int64_t offset, diff --git a/paddle/fluid/distributed/collective/types.h b/paddle/fluid/distributed/collective/types.h index 5dfb611821c82c32eb3089df25fde0703c8573cb..3bafa53727c7217c3cd0a0512ca3a679cb7ca134 100644 --- a/paddle/fluid/distributed/collective/types.h +++ b/paddle/fluid/distributed/collective/types.h @@ -48,6 +48,10 @@ struct ScatterOptions { int root_rank = 0; }; +struct GatherOptions { + int root_rank = 0; +}; + struct ReduceScatterOptions { ReduceOp reduce_op = ReduceOp::SUM; }; diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 3c201a7cdfbe12b622b2cd8eae368ca70197bdb7..f3639ced91f1745879c13fd6f6ff5c1d4e49177a 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -113,6 +113,10 @@ void BindDistributed(py::module *m) { .def_readwrite("reduce_op", &distributed::ReduceOptions::reduce_op) .def_readwrite("source_root", &distributed::ReduceOptions::root_rank); + py::class_<distributed::GatherOptions>(*m, "GatherOptions") + .def(py::init<>()) + .def_readwrite("root_rank", &distributed::GatherOptions::root_rank); + auto ProcessGroup = py::class_<distributed::ProcessGroup, std::shared_ptr<distributed::ProcessGroup>>(*m, "ProcessGroup") @@ -521,7 +525,44 @@ void BindDistributed(py::module *m) { py::arg("src"), py::arg("sync_op"), py::call_guard<py::gil_scoped_release>()) + .def( + "gather", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_gather_tensor_list, + int dst, + bool sync_op, + bool use_calc_stream) { + auto out_tensor_list = + CastPyArg2VectorOfTensor(py_gather_tensor_list.ptr(), 0); + Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0); + auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>( + stack_out_tensor.impl()); + auto *out_dense = p_out_tensor.get(); + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>( + in_tensor.impl()); + auto in_dense = *p_in_tensor; + + auto *dev_ctx = + self.GetDeviceContext(in_tensor.place(), use_calc_stream); + distributed::GatherOptions gather_ops{dst}; + auto task = self.Gather( + out_dense, in_dense, gather_ops, sync_op, use_calc_stream); + SplitTensor(*dev_ctx, *out_dense, &out_tensor_list); + if (!use_calc_stream) { + // calculate stream will wait comm stream + task->UpdateWaitChain(*dev_ctx); + } + return task; + }, + py::arg("in"), + py::arg("out"), + py::arg("dst"), + py::arg("sync_op"), + py::arg("use_calc_stream"), + py::call_guard<py::gil_scoped_release>()) .def( "barrier", [](distributed::ProcessGroup &self, int8_t device_id) { diff --git a/paddle/phi/core/distributed/check/nccl_dynamic_check.cc b/paddle/phi/core/distributed/check/nccl_dynamic_check.cc index
6cb4c8cfe17519ee98e9ccd10a09dd7b16982496..da8fb5d98a82fb5f14a64ae480f4a7b75c91a27a 100644 --- a/paddle/phi/core/distributed/check/nccl_dynamic_check.cc +++ b/paddle/phi/core/distributed/check/nccl_dynamic_check.cc @@ -64,7 +64,7 @@ void NCCLDynamicCheck::CheckDataType(const phi::DenseTensor& tensor, PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBroadcast(dtype_device, dtype_device, - kSize, + 1, ncclInt64, root_rank, comm, @@ -106,7 +106,7 @@ void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& tensor, PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBroadcast(shape_device, shape_device, - kSize, + 1, ncclInt64, root_rank, comm, @@ -141,10 +141,9 @@ void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor, PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&in_shape_device, kSize)); PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy( in_shape_device, &in_shape_host, kSize, gpuMemcpyHostToDevice)); - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclReduce(in_shape_device, in_shape_device, - kSize, + 1, ncclInt64, ncclSum, rank, @@ -159,5 +158,42 @@ void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor, PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device)); } } + +void NCCLDynamicCheck::CheckGatherShape( + const phi::DenseTensor& in_tensor, + const std::vector<phi::DenseTensor>& out_tensors, + int root_rank, + int cur_rank, + int world_size, + ncclComm_t comm) { + std::vector<int64_t> shapes(world_size, 0); + shapes[cur_rank] = in_tensor.numel(); + int64_t* in_shape_device; + PADDLE_ENFORCE_GPU_SUCCESS( + gpuMalloc(&in_shape_device, world_size * sizeof(int64_t))); + PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(in_shape_device, + shapes.data(), + world_size * sizeof(int64_t), + gpuMemcpyHostToDevice)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(in_shape_device, + in_shape_device, + world_size, + ncclInt64, + ncclSum, + comm, + kDefaultStream)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(shapes.data(), + in_shape_device, + world_size * sizeof(int64_t), + gpuMemcpyDeviceToHost)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device)); + + if (cur_rank == root_rank) { + for (int i = 0; i < world_size; i++) { + CheckShape(out_tensors[i], shapes[i]); + } + } +} } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/check/nccl_dynamic_check.h b/paddle/phi/core/distributed/check/nccl_dynamic_check.h index 64c13e2a760e591da3fd29f2b79ef0c2b3b5a105..23e8386d6f2aff399b573e8c8e3284e6ddf4b191 100644 --- a/paddle/phi/core/distributed/check/nccl_dynamic_check.h +++ b/paddle/phi/core/distributed/check/nccl_dynamic_check.h @@ -52,6 +52,14 @@ struct NCCLDynamicCheck { int world_size, ncclComm_t comm); + // can be used to check gather and all gather + static void CheckGatherShape(const phi::DenseTensor& in_tensor, + const std::vector<phi::DenseTensor>& out_tensors, + int root_rank, + int cur_rank, + int world_size, + ncclComm_t comm); + private: // `0` represents default stream for both cuda & hip static constexpr gpuStream_t kDefaultStream = 0; diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 74d3eaaa08f8d5e90c1de793eeda003ffc9dc447..e86b1bc32ec6f2de30832f0ffce992e058377b35 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -46,6 +46,7 @@ from .communication import ( reduce, send, scatter, + gather, scatter_object_list, isend, recv, @@ -82,6 +83,7 @@ __all__ = [ # noqa "spawn", "launch", "scatter", + "gather", "scatter_object_list", "broadcast", "broadcast_object_list", diff --git
a/python/paddle/distributed/communication/__init__.py b/python/paddle/distributed/communication/__init__.py index 1d21e8103353c009ec0975f9593248a0cb97202d..39ee19c9168ab8c39b25e81ebedcb00b4b92f15b 100644 --- a/python/paddle/distributed/communication/__init__.py +++ b/python/paddle/distributed/communication/__init__.py @@ -18,6 +18,7 @@ from .reduce import reduce, ReduceOp from .send import send, isend from .recv import recv, irecv from .scatter import scatter, scatter_object_list +from .gather import gather from .batch_isend_irecv import batch_isend_irecv, P2POp from .reduce_scatter import reduce_scatter from .all_to_all import alltoall, alltoall_single diff --git a/python/paddle/distributed/communication/gather.py b/python/paddle/distributed/communication/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..089f73d1d0c7eae36465bccf9e239f8e5e6836e9 --- /dev/null +++ b/python/paddle/distributed/communication/gather.py @@ -0,0 +1,60 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from paddle import framework +from paddle.distributed.communication import stream + + +def gather(tensor, gather_list=None, dst=0, group=None, sync_op=True): + """ + + Gather tensors from all participators. + + Args: + tensor (Tensor): The input Tensor. Its data type + should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. + gather_list (list): A list of Tensors to hold the gathered tensors. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Only meaningful on the dst rank. Default value is None. + dst (int): The destination rank id. Default value is 0. + group (Group, optional): The group instance returned by new_group or None for the global default group. + sync_op (bool, optional): Whether this op is a sync op. The default value is True. + + Returns: + Async work handle, which can be waited on, if sync_op is set to False. + None, if sync_op is True. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + gather_list = [] + if dist.get_rank() == 0: + data = paddle.to_tensor([1, 2, 3]) + dist.gather(data, gather_list, dst=0) + else: + data = paddle.to_tensor([4, 5, 6]) + dist.gather(data, gather_list, dst=0) + print(gather_list) + # [[1, 2, 3], [4, 5, 6]] (2 GPUs, out for rank 0) + # [] (2 GPUs, out for rank 1) + """ + assert ( + framework.in_dygraph_mode() + ), "gather doesn't support static graph mode yet."
+ return stream.gather(tensor, gather_list, dst, group, sync_op) diff --git a/python/paddle/distributed/communication/stream/__init__.py b/python/paddle/distributed/communication/stream/__init__.py index 423b655c6284d389399b885a72f16d1e1bc27e96..77c02e33eeae2fd52ac74db15e5bff5f58415078 100644 --- a/python/paddle/distributed/communication/stream/__init__.py +++ b/python/paddle/distributed/communication/stream/__init__.py @@ -21,6 +21,7 @@ from .reduce_scatter import reduce_scatter from .recv import recv from .scatter import scatter from .send import send +from .gather import gather __all__ = [ "all_gather", @@ -33,4 +34,5 @@ __all__ = [ "recv", "scatter", "send", + "gather", ] diff --git a/python/paddle/distributed/communication/stream/gather.py b/python/paddle/distributed/communication/stream/gather.py new file mode 100644 index 0000000000000000000000000000000000000000..df3db07eb59b11951c5394d96259c724d44597ce --- /dev/null +++ b/python/paddle/distributed/communication/stream/gather.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import paddle +import paddle.distributed as dist +from paddle import framework +from paddle.distributed.communication.group import ( + _get_global_group, + _get_or_throw_group_rank, + _warn_cur_rank_not_in_group, +) + + +def _gather_in_dygraph( + tensor, gather_list, dst_rank_in_group, group, sync_op, use_calc_stream +): + nranks = group.nranks + if group.rank == dst_rank_in_group: + if len(gather_list) == 0: + gather_list += [paddle.empty_like(tensor) for _ in range(nranks)] + else: + gather_list = [tensor for _ in range(nranks)] + + assert ( + len(gather_list) == nranks + ), "gather_list length {} and nranks {} are not equal".format( + len(gather_list), nranks + ) + + task = group.process_group.gather( + tensor, gather_list, dst_rank_in_group, sync_op, use_calc_stream + ) + + if sync_op: + task.wait() + + return task + + +def gather( + tensor, + gather_list=None, + dst=0, + group=None, + sync_op=True, + use_calc_stream=False, +): + + """ + + Gather tensors from all participators. + + Args: + tensor (Tensor): The input Tensor. Its data type + should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. + gather_list (list): A list of Tensors to hold the gathered tensors. Every element in the list must be a Tensor whose data type + should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Only meaningful on the dst rank. Default value is None. + dst (int): The destination rank id. Default value is 0. + group (Group, optional): The group instance returned by new_group or None for the global default group. + sync_op (bool, optional): Whether this op is a sync op. The default value is True. + use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If None is given, False is used by default. This + option is designed for high performance demands; be careful when turning it on unless you clearly know its meaning.
+ + Returns: + Async work handle, which can be waited on, if sync_op is set to False. + None, if sync_op is True. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + gather_list = [] + if dist.get_rank() == 0: + data = paddle.to_tensor([1, 2, 3]) + dist.stream.gather(data, gather_list, dst=0) + else: + data = paddle.to_tensor([4, 5, 6]) + dist.stream.gather(data, gather_list, dst=0) + print(gather_list) + # [[1, 2, 3], [4, 5, 6]] (2 GPUs, out for rank 0) + # [] (2 GPUs, out for rank 1) + """ + + assert ( + framework.in_dygraph_mode() + ), "gather doesn't support static graph mode yet." + + if _warn_cur_rank_not_in_group(group): + return + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be true in sync op behavior." + ) + + # NOTE(liuzhenhai): Only the dst rank needs to specify the gather_list argument. + # Other ranks that pass this argument in will have it ignored with a warning. + # The value passed in by a non-dst rank is meaningless, for it will be ignored. + if dst != dist.get_rank(): + if gather_list is not None: + warnings.warn( + "Specifying `gather_list` is meaningless for a rank which is not dst." + ) + gather_list = [] + else: + + assert ( + gather_list is not None + ), "gather_list must not be None for the dst rank" + + group = _get_global_group() if group is None else group + dst_rank_in_group = _get_or_throw_group_rank(dst, group) + return _gather_in_dygraph( + tensor, gather_list, dst_rank_in_group, group, sync_op, use_calc_stream + ) diff --git a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt index 792daf1479896c9bd7818c8f7e800fa2ac704222..b7d9c54e86a34256b5efbd88aa73c93b590a2305 100644 --- a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt @@ -71,7 +71,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_allreduce_api MODULES test_collective_allreduce_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_allreduce_api - PROPERTIES TIMEOUT "250" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -195,7 +195,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_reduce_api MODULES test_collective_reduce_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_reduce_api - PROPERTIES TIMEOUT "230" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "500" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) bash_test_modules( @@ -239,6 +239,13 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) set_tests_properties(test_collective_scatter_object_list_api PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_collective_gather_api MODULES test_collective_gather_api ENVS + "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_collective_gather_api + PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST") +endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( test_collective_sendrecv MODULES test_collective_sendrecv ENVS @@ -251,7 +258,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_sendrecv_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( diff --git a/python/paddle/fluid/tests/unittests/collective/README.md b/python/paddle/fluid/tests/unittests/collective/README.md index 6f89b19ad8657418f4967cd19ccebe563209a251..a06cf6dc39640cf2544f510ee5d3184d105fd358 100644 --- a/python/paddle/fluid/tests/unittests/collective/README.md +++ b/python/paddle/fluid/tests/unittests/collective/README.md @@ -21,7 +21,7 @@ ```bash python3 ${PADDLE_ROOT}/tools/gen_ut_cmakelists.py -f ${PADDLE_ROOT}/python/paddle/fluid/tests/unittests/collective/testslist.csv ``` - Then the cmd generates a file named CMakeLists.txt in the save directory with the testslist.csv. + Then the cmd generates a file named CMakeLists.txt in the same directory with the testslist.csv. * usgae: The command accepts --files/-f or --dirpaths/-d options, both of which accepts multiple values. Option -f accepts a list of testslist.csv. diff --git a/python/paddle/fluid/tests/unittests/collective/collective_gather_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_gather_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..063d9043aa2dff44d0f9bd16250e74e7c16e9704 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/collective_gather_api_dygraph.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import test_collective_api_base as test_base + +import paddle +import paddle.distributed as dist +from paddle import fluid + + +class TestCollectiveGatherAPI(test_base.TestCollectiveAPIRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank, indata=None): + with fluid.program_guard(main_prog, startup_program): + gather_list = [] + # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16 + if indata.dtype == "bfloat16": + tindata = paddle.to_tensor(indata, "float32").cast("uint16") + dist.gather(tindata, gather_list, dst=0) + return [e.cast("float32").numpy() for e in gather_list] + else: + tindata = paddle.to_tensor(indata) + dist.gather(tindata, gather_list, dst=0) + return [e.numpy() for e in gather_list] + + +if __name__ == "__main__": + test_base.runtime_main(TestCollectiveGatherAPI, "gather") diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt index bdb789c813be55714564f76e0b3ab458d4742d2c..90be9cba14f0a424857230905d1b19ea04f82fe3 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt @@ -2,8 +2,8 @@ # Please don't modify this file manually. 
# If you need to change unittests in this file, please modify testslist.csv in the current directory # and then run the command `python3 ${PADDLE_ROOT}/tools/gen_ut_cmakelists.py -f ${CURRENT_DIRECTORY}/testslist.csv` -set(LOCAL_ALL_PLAT ON) set(LOCAL_ALL_ARCH ON) +set(LOCAL_ALL_PLAT ON) if((WITH_GPU OR WITH_XPU OR WITH_ASCEND @@ -650,6 +650,18 @@ if((WITH_GPU "PADDLE_DIST_UT_PORT=21270;http_proxy=;https_proxy=") set_tests_properties(test_c_comm_init_op PROPERTIES TIMEOUT "120") endif() +if((WITH_GPU) AND (LINUX)) + bash_test_modules( + test_fused_attention_pass_with_mp + START_BASH + test_fused_attention_pass_with_mp.sh + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=21400;http_proxy=;https_proxy=") + set_tests_properties(test_fused_attention_pass_with_mp PROPERTIES TIMEOUT + "120") +endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) bash_test_modules( test_ir_pass_pipeline @@ -882,15 +894,3 @@ if((WITH_GPU) AND (LINUX)) set_tests_properties(test_dygraph_save_for_auto_infer PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() -if(WITH_GPU) - bash_test_modules( - test_fused_attention_pass_with_mp - START_BASH - test_fused_attention_pass_with_mp.sh - LABELS - "RUN_TYPE=DIST" - ENVS - "PADDLE_DIST_UT_PORT=21400;http_proxy=;https_proxy=") - set_tests_properties(test_fused_attention_pass_with_mp PROPERTIES TIMEOUT - "120") -endif() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv index 2da2aa06ca709b010ed099198d16b25aa2b655a9..9ac8fdf65729eb21dbc83a1580221e17f47a6be7 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv +++ b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv @@ -56,7 +56,7 @@ test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU;ASCEND;ASCEND_CL,,,test_ test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_new_group,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=, test_c_comm_init_op,LINUX,GPU;XPU;ASCEND;ASCEND_CL,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=, -test_fused_attention_pass_with_mp,LINUX,GPU;;;,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=, +test_fused_attention_pass_with_mp,LINUX,GPU,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=, test_ir_pass_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_gather_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_gather_api.py new file mode 100644 index 0000000000000000000000000000000000000000..68b6d12878fb6d456e66afd573de47896eb07ca3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_gather_api.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_collective_api_base import TestDistBase + +import paddle + +paddle.enable_static() + + +class TestCollectiveGatherAPI(TestDistBase): + def _setup_config(self): + pass + + def test_gather_nccl_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if self._nccl_version >= 2100: + dtypes_to_test.append("bfloat16") + for dtype in dtypes_to_test: + self.check_with_place( + "collective_gather_api_dygraph.py", + "gather", + "nccl", + static_mode="0", + dtype=dtype, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/testslist.csv b/python/paddle/fluid/tests/unittests/collective/testslist.csv index 6c02e39c422318b815f04a4ada43d38761dfaa05..cf2b6c6757b3071e0a1316ac2f2d7381c9c137b0 100644 --- a/python/paddle/fluid/tests/unittests/collective/testslist.csv +++ b/python/paddle/fluid/tests/unittests/collective/testslist.csv @@ -7,13 +7,13 @@ test_c_split,linux,gpu;rocm,120,DIST,test_runner.py,2,,PYTHONPATH=..;http_proxy= test_collective_split_embedding,linux,rocm;gpu,300,DIST,../dist_test.sh,2,,PYTHONPATH=..;http_proxy=;https_proxy=, test_collective_allgather_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_allgather_object_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., -test_collective_allreduce_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_allreduce_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_alltoall_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_alltoall_single,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_alltoall_single_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., -test_collective_broadcast_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_broadcast_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_broadcast_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., @@ -22,14 +22,15 @@ test_collective_isend_irecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_p test_collective_optimizer,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., 
test_collective_process_group,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_reduce,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., -test_collective_reduce_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_reduce_api,linux,gpu;rocm,500,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_reduce_scatter_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_scatter_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_scatter_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_gather_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., -test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., +test_collective_sendrecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index 550b0e24e487cd89588720a11f2e0c12a700be1b..d53231bdc31d4bc3038ae5510bb6326f299d5fbf 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -392,6 +392,14 @@ class TestDistBase(unittest.TestCase): need_result2 = [need_result[len(need_result) // 2 :]] self.assertEqual(need_result1, tr0_out) self.assertEqual(need_result2, tr1_out) + elif col_type == "gather": + # rank 0 gather all tensor + self.assertEqual(len(tr0_out), 2) + # rank 1 get nothing + self.assertEqual(len(tr1_out), 0) + # check values + np.testing.assert_equal(input1, tr0_out[0]) + np.testing.assert_equal(input2, tr0_out[1]) elif col_type == "reduce_scatter": need_result = input1 + input2 need_result1 = need_result[0 : need_result.shape[0] // 2]
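
For reference, here is a minimal usage sketch of the `paddle.distributed.gather` API added by this change, assuming a 2-GPU job launched with `python -m paddle.distributed.launch`; the script name `gather_demo.py` is hypothetical, and the tensor values mirror the docstring example above.

    # gather_demo.py -- hypothetical file name; run with:
    #   python -m paddle.distributed.launch --gpus 0,1 gather_demo.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    # each rank contributes its own tensor
    if dist.get_rank() == 0:
        data = paddle.to_tensor([1, 2, 3])
    else:
        data = paddle.to_tensor([4, 5, 6])
    gather_list = []  # only filled on the dst rank; other ranks leave it empty
    dist.gather(data, gather_list, dst=0)
    print(gather_list)
    # rank 0: [[1, 2, 3], [4, 5, 6]]
    # rank 1: []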