From b565b349752d0917fd5ca3f118ad1c618a098db9 Mon Sep 17 00:00:00 2001
From: lilong12
Date: Thu, 3 Mar 2022 11:44:59 +0800
Subject: [PATCH] add communication api for ProcessGroupNCCL (#40097)

---
 .../distributed/collective/ProcessGroup.h     |  29 ++++
 .../collective/ProcessGroupNCCL.cc            | 143 ++++++++++++++++++
 .../distributed/collective/ProcessGroupNCCL.h |  14 ++
 paddle/fluid/distributed/collective/Types.h   |   9 ++
 paddle/fluid/pybind/distributed_py.cc         |  57 +++++++
 .../tests/unittests/process_group_nccl.py    | 100 +++++++++++-
 6 files changed, 345 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h
index e4f2720520..e43d0e8c18 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.h
+++ b/paddle/fluid/distributed/collective/ProcessGroup.h
@@ -117,6 +117,35 @@ class ProcessGroup {
         "ProcessGroup%s does not support receive", GetBackendName()));
   }
 
+  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
+      std::vector<Tensor>& in_tensors /* tensors */,     // NOLINT
+      std::vector<Tensor>& out_tensors /* tensors */) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support AllGather", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
+      std::vector<Tensor>& in /* tensors */,     // NOLINT
+      std::vector<Tensor>& out /* tensors */) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support AllToAll", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
+      std::vector<Tensor>& tensors /* tensors */,  // NOLINT
+      const ReduceOptions& opts) {                 // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support Reduce", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
+      std::vector<Tensor>& in_tensors /* tensors */,   // NOLINT
+      std::vector<Tensor>& out_tensors /* tensors */,  // NOLINT
+      const ScatterOptions&) {                         // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support Scatter", GetBackendName()));
+  }
+
  protected:
   const int rank_;
   const int size_;
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 5d96e730aa..88d8fb69eb 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -473,5 +473,148 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
   return task;
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
+    std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(out_tensors), true,
+      platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
+  return Collective(
+      in_tensors, out_tensors,
+      [&](const Tensor& input, Tensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        auto input_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
+        auto output_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
+        return platform::dynload::ncclAllGather(
+            input_tensor->data(), output_tensor->data(), input_tensor->numel(),
+            platform::ToNCCLDataType(input.type()), comm, stream);
+      },
+      CommType::ALLGATHER);
+}
+
+void* GetPointerByOffset(void* raw_pointer, size_t offset,
+                         experimental::DataType type) {
+  if (type == experimental::DataType::FLOAT32) {
+    return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::FLOAT64) {
+    return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::INT32) {
+    return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::INT64) {
+    return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
+                                   offset);
+  } else if (type == experimental::DataType::FLOAT16) {
+    return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
+                                   offset);
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "This datatype in nccl is not supported."));
+  }
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
+    std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(out_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      in_tensors, out_tensors,
+      [&](const Tensor& input, Tensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        auto input_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
+        auto output_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
+        size_t offset = 0;
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+        for (auto i = 0; i < size_; i++) {
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
+              GetPointerByOffset(input_tensor->data(), offset, input.type()),
+              input_tensor->numel() / size_,
+              platform::ToNCCLDataType(input.type()), i, comm, stream));
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+              GetPointerByOffset(output_tensor->data(), offset, input.type()),
+              input_tensor->numel() / size_,
+              platform::ToNCCLDataType(input.type()), i, comm, stream));
+          offset += input_tensor->numel() / size_;
+        }
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+      },
+      CommType::ALLREDUCE);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
+    std::vector<Tensor>& tensors, const ReduceOptions& opts) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      tensors, tensors,
+      [&](const Tensor& input, Tensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        auto input_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
+        auto output_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
+            input_tensor->data(), output_tensor->data(), input.numel(),
+            platform::ToNCCLDataType(input.type()),
+            ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, stream));
+      },
+      CommType::REDUCE);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
+    std::vector<Tensor>& in_tensors, std::vector<Tensor>& out_tensors,
+    const ScatterOptions& opts) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(out_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      in_tensors, out_tensors,
+      [&](const Tensor& input, Tensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        auto input_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(input.impl());
+        auto output_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(output.impl());
+        size_t offset = 0;
+        if (rank_ == opts.root_rank) {
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+          for (auto i = 0; i < size_; i++) {
+            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
+                GetPointerByOffset(input_tensor->data(), offset, input.type()),
+                input_tensor->numel() / size_,
+                platform::ToNCCLDataType(input.type()), i, comm, stream));
+            offset += input_tensor->numel() / size_;
+          }
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+              output_tensor->data(), input_tensor->numel() / size_,
+              platform::ToNCCLDataType(input.type()), opts.root_rank, comm,
+              stream));
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+        } else {
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+              output_tensor->data(), input_tensor->numel() / size_,
+              platform::ToNCCLDataType(input.type()), opts.root_rank, comm,
+              stream));
+        }
+      },
+      CommType::SCATTER);
+}
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
index cfeb6467f0..d63a5e7683 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -98,6 +98,20 @@ class ProcessGroupNCCL : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> Recv(std::vector<Tensor>& tensors,
                                            int src_rank) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllGather(
+      std::vector<Tensor>& in_tensors,
+      std::vector<Tensor>& out_tensors) override;
+
+  std::shared_ptr<ProcessGroup::Task> AllToAll(
+      std::vector<Tensor>& in, std::vector<Tensor>& out) override;
+
+  std::shared_ptr<ProcessGroup::Task> Reduce(
+      std::vector<Tensor>& tensors, const ReduceOptions& opts) override;
+
+  std::shared_ptr<ProcessGroup::Task> Scatter(std::vector<Tensor>& in_tensors,
+                                              std::vector<Tensor>& out_tensors,
+                                              const ScatterOptions&) override;
+
  protected:
   virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
       std::vector<Place> places, int rank, CommType opType,
diff --git a/paddle/fluid/distributed/collective/Types.h b/paddle/fluid/distributed/collective/Types.h
index 699222ac45..973f7c6435 100644
--- a/paddle/fluid/distributed/collective/Types.h
+++ b/paddle/fluid/distributed/collective/Types.h
@@ -36,5 +36,14 @@ struct BarrierOptions {
   std::vector<int> place_ids;
 };
 
+struct ReduceOptions {
+  ReduceOp reduce_op = ReduceOp::SUM;
+  int root_rank = 0;
+};
+
+struct ScatterOptions {
+  int root_rank = 0;
+};
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 3b5644764a..1751286335 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -77,6 +77,11 @@ void BindDistributed(py::module *m) {
       .def(py::init<>())
       .def_readwrite("place_ids", &distributed::BarrierOptions::place_ids);
 
+  py::class_<distributed::ReduceOptions>(*m, "ReduceOptions")
+      .def(py::init<>())
+      .def_readwrite("reduce_op", &distributed::ReduceOptions::reduce_op)
+      .def_readwrite("source_root", &distributed::ReduceOptions::root_rank);
+
   auto ProcessGroup =
       py::class_<distributed::ProcessGroup,
                  std::shared_ptr<distributed::ProcessGroup>>(*m, "ProcessGroup")
@@ -134,6 +139,58 @@ void BindDistributed(py::module *m) {
                  return self.Recv(tensors, src);
                },
                py::arg("tensor"), py::arg("src"),
+               py::call_guard<py::gil_scoped_release>())
+
+          .def("all_gather",
+               [](distributed::ProcessGroup &self, py::handle py_in_tensor,
+                  py::handle py_out_tensor) {
+                 auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+                 auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
+                 std::vector<Tensor> in_tensors = {in_tensor};
+                 std::vector<Tensor> out_tensors = {out_tensor};
+                 return self.AllGather(in_tensors, out_tensors);
+               },
+               py::arg("in"), py::arg("out"),
+               py::call_guard<py::gil_scoped_release>())
+
+          .def("alltoall",
+               [](distributed::ProcessGroup &self, py::handle py_in_tensor,
+                  py::handle py_out_tensor) {
+                 auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+                 auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
+                 std::vector<Tensor> in_tensors = {in_tensor};
+                 std::vector<Tensor> out_tensors = {out_tensor};
+                 return self.AllToAll(in_tensors, out_tensors);
+               },
+               py::arg("in"), py::arg("out"),
+               py::call_guard<py::gil_scoped_release>())
+
+          .def("reduce",
+               [](distributed::ProcessGroup &self, py::handle py_in_tensor,
+                  int dst, distributed::ReduceOp op) {
+                 auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+                 distributed::ReduceOptions opts;
+                 opts.reduce_op = op;
+                 opts.root_rank = dst;
+                 std::vector<Tensor> tensors = {in_tensor};
+                 return self.Reduce(tensors, opts);
+               },
+               py::arg("tensor"), py::arg("dst"),
+               py::arg("op") = distributed::ReduceOp::SUM,
+               py::call_guard<py::gil_scoped_release>())
+
+          .def("scatter",
+               [](distributed::ProcessGroup &self, py::handle py_in_tensor,
+                  py::handle py_out_tensor, int src) {
+                 auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
+                 auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
+                 distributed::ScatterOptions opts;
+                 opts.root_rank = src;
+                 std::vector<Tensor> in_tensors = {in_tensor};
+                 std::vector<Tensor> out_tensors = {out_tensor};
+                 return self.Scatter(in_tensors, out_tensors, opts);
+               },
+               py::arg("in"), py::arg("out"), py::arg("src"),
               py::call_guard<py::gil_scoped_release>());
 
 #if defined(PADDLE_WITH_NCCL)
diff --git a/python/paddle/fluid/tests/unittests/process_group_nccl.py b/python/paddle/fluid/tests/unittests/process_group_nccl.py
index 8ec5d13c56..4833cea9a8 100644
--- a/python/paddle/fluid/tests/unittests/process_group_nccl.py
+++ b/python/paddle/fluid/tests/unittests/process_group_nccl.py
@@ -144,23 +144,109 @@ class TestProcessGroupFp32(unittest.TestCase):
 
         print("test barrier api ok\n")
 
-        # test send/recv
+        # test allgather
         # rank 0
         x = np.random.random(self.shape).astype(self.dtype)
+        y = np.random.random(self.shape).astype(self.dtype)
+        tensor_x = paddle.to_tensor(x)
+        tensor_y = paddle.to_tensor(y)
+        out_shape = list(self.shape)
+        out_shape[0] *= 2
+        out = np.random.random(out_shape).astype(self.dtype)
+        tensor_out = paddle.to_tensor(out)
+        if pg.rank() == 0:
+            task = pg.all_gather(tensor_x, tensor_out)
+            task.wait()
+            paddle.device.cuda.synchronize()
+        # rank 1
+        else:
+            task = pg.all_gather(tensor_y, tensor_out)
+            task.wait()
+            paddle.device.cuda.synchronize()
+        out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
+        out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2],
+                             [out_shape[0]])
+        assert np.array_equal(tensor_x, out_1)
+        assert np.array_equal(tensor_y, out_2)
+        print("test allgather api ok\n")
+
+        # test alltoall
+        # rank 0
+        x = np.random.random(self.shape).astype(self.dtype)
+        y = np.random.random(self.shape).astype(self.dtype)
+        out1 = np.random.random(self.shape).astype(self.dtype)
+        out2 = np.random.random(self.shape).astype(self.dtype)
         tensor_x = paddle.to_tensor(x)
+        tensor_y = paddle.to_tensor(y)
+        tensor_out1 = paddle.to_tensor(out1)
+        tensor_out2 = paddle.to_tensor(out2)
+        raw_tensor_x_2 = paddle.slice(tensor_x, [0], [self.shape[0] // 2],
+                                      [self.shape[0]])
+        raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0],
+                                      [self.shape[0] // 2])
         if pg.rank() == 0:
-            task = pg.send(tensor_x, dst=1)
+            task = pg.alltoall(tensor_x, tensor_out1)
             task.wait()
             paddle.device.cuda.synchronize()
         # rank 1
         else:
-            y = np.random.random(self.shape).astype(self.dtype)
-            tensor_y = paddle.to_tensor(y)
-            task = pg.recv(tensor_y, src=0)
+            task = pg.alltoall(tensor_y, tensor_out2)
             task.wait()
             paddle.device.cuda.synchronize()
-        assert np.array_equal(tensor_x, tensor_y)
-        print("test send/recv api ok\n")
+        out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2],
+                              [self.shape[0]])
+        out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
+        if pg.rank() == 0:
+            assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
+        else:
+            assert np.array_equal(out2_1, raw_tensor_x_2)
+        print("test alltoall api ok\n")
+
+        # test Reduce
+        # rank 0
+        x = np.random.random(self.shape).astype(self.dtype)
+        y = np.random.random(self.shape).astype(self.dtype)
+        tensor_x = paddle.to_tensor(x)
+        tensor_y = paddle.to_tensor(y)
+        sum_result = tensor_x + tensor_y
+        if pg.rank() == 0:
+            task = pg.reduce(tensor_x, 0)
+            task.wait()
+            paddle.device.cuda.synchronize()
+        # rank 1
+        else:
+            task = pg.reduce(tensor_y, 0)
+            task.wait()
+            paddle.device.cuda.synchronize()
+        if pg.rank() == 0:
+            assert np.array_equal(tensor_x, sum_result)
+        print("test reduce sum api ok\n")
+
+        # test Scatter
+        # rank 0
+        in_shape = list(self.shape)
+        in_shape[0] *= 2
+        x = np.random.random(in_shape).astype(self.dtype)
+        y = np.random.random(self.shape).astype(self.dtype)
+        tensor_x = paddle.to_tensor(x)
+        tensor_y = paddle.to_tensor(y)
+        if pg.rank() == 0:
+            task = pg.scatter(tensor_x, tensor_y, 0)
+            task.wait()
+            paddle.device.cuda.synchronize()
+        # rank 1
+        else:
+            task = pg.scatter(tensor_x, tensor_y, 0)
+            task.wait()
+            paddle.device.cuda.synchronize()
+        out1 = paddle.slice(tensor_x, [0], [0], [self.shape[0]])
+        out2 = paddle.slice(tensor_x, [0], [self.shape[0]],
+                            [self.shape[0] * 2])
+        if pg.rank() == 0:
+            assert np.array_equal(tensor_y, out1)
+        else:
+            assert np.array_equal(tensor_y, out2)
+        print("test scatter api ok\n")
 
 
 class TestProcessGroupFp16(TestProcessGroupFp32):
--
GitLab
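Editor's note (below the patch trailer, not part of the change itself): a minimal usage sketch of the collectives this patch exposes to Python. It assumes `pg` is a two-rank ProcessGroupNCCL instance already constructed the way the unit test's setUp does (that setup code is outside the hunk shown above) and that each process is bound to its own GPU; the shape and dtype are illustrative only.

# Usage sketch: exercise the new all_gather and reduce bindings.
# `pg` is assumed to be an existing two-rank ProcessGroupNCCL object.
import numpy as np
import paddle

shape = [20, 20]
x = paddle.to_tensor(np.random.random(shape).astype("float32"))
gathered = paddle.to_tensor(
    np.zeros([shape[0] * 2, shape[1]], dtype="float32"))

# all_gather(in, out): every rank contributes x; gathered holds both halves.
task = pg.all_gather(x, gathered)
task.wait()
paddle.device.cuda.synchronize()

# reduce(tensor, dst, op=ReduceOp.SUM): sum both ranks' tensors into rank 0.
task = pg.reduce(x, 0)
task.wait()
paddle.device.cuda.synchronize()

As in the test above, each call returns a task whose wait() only enqueues the dependency on the communication stream, so the explicit device synchronize is what guarantees the result is visible before the host reads it.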