From 9ce31e965e53d499f6e3811c51861ca996b5debe Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Wed, 21 Sep 2022 14:49:41 +0800
Subject: [PATCH] Mpi final dev simple (#46247)

---
 CMakeLists.txt                                |   7 +
 cmake/mpi.cmake                               |  33 ++
 .../distributed/collective/CMakeLists.txt     |   7 +
 .../fluid/distributed/collective/MPITools.cc  |  56 ++
 .../fluid/distributed/collective/MPITools.h   |  53 ++
 .../distributed/collective/ProcessGroup.cc    |   8 +
 .../distributed/collective/ProcessGroup.h     |   3 +
 .../distributed/collective/ProcessGroupMPI.cc | 467 ++++++++++++++++
 .../distributed/collective/ProcessGroupMPI.h  | 211 ++++++++
 paddle/fluid/pybind/CMakeLists.txt            |   7 +
 paddle/fluid/pybind/distributed_py.cc         |  23 +
 paddle/fluid/pybind/pybind.cc                 |  19 +
 .../tests/unittests/collective/CMakeLists.txt |  12 +
 .../unittests/collective/process_group_mpi.py | 506 ++++++++++++++++++
 .../unittests/collective/test_mpi_comm.sh     |  27 +
 .../tests/unittests/collective/testslist.csv  |   1 +
 16 files changed, 1440 insertions(+)
 create mode 100644 cmake/mpi.cmake
 create mode 100644 paddle/fluid/distributed/collective/MPITools.cc
 create mode 100644 paddle/fluid/distributed/collective/MPITools.h
 create mode 100644 paddle/fluid/distributed/collective/ProcessGroupMPI.cc
 create mode 100644 paddle/fluid/distributed/collective/ProcessGroupMPI.h
 create mode 100644 python/paddle/fluid/tests/unittests/collective/process_group_mpi.py
 create mode 100644 python/paddle/fluid/tests/unittests/collective/test_mpi_comm.sh

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 290310858fb..304852f71af 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -485,6 +485,9 @@ if(WITH_DISTRIBUTE)
         ON
         CACHE STRING "Enable GLOO when compiling WITH_DISTRIBUTE=ON." FORCE)
   endif()
+  set(WITH_MPI
+      ON
+      CACHE STRING "Enable MPI when compiling WITH_DISTRIBUTE=ON." FORCE)
   if(WITH_ASCEND_CL AND NOT WITH_ARM_BRPC)
     # disable WITH_PSCORE for NPU before include third_party
     message(
@@ -509,6 +512,10 @@ if(WITH_DISTRIBUTE)
   endif()
 endif()
 
+if(WITH_MPI)
+  include(mpi)
+endif()
+
 include(third_party
 )# download, build, install third_party, Contains about 20+ dependencies
diff --git a/cmake/mpi.cmake b/cmake/mpi.cmake
new file mode 100644
index 00000000000..650eaeceb76
--- /dev/null
+++ b/cmake/mpi.cmake
@@ -0,0 +1,33 @@
+if(NOT WITH_DISTRIBUTE OR NOT WITH_MPI)
+  return()
+endif()
+
+find_package(MPI)
+
+if(NOT MPI_CXX_FOUND)
+  set(WITH_MPI
+      OFF
+      CACHE STRING "Disable MPI" FORCE)
+  message(WARNING "MPI support was not found on the current system")
+  return()
+endif()
+
+message(STATUS "MPI compile flags: " ${MPI_CXX_COMPILE_FLAGS})
+message(STATUS "MPI include path: " ${MPI_CXX_INCLUDE_PATH})
+message(STATUS "MPI link flags: " ${MPI_CXX_LINK_FLAGS})
+message(STATUS "MPI libraries: " ${MPI_CXX_LIBRARIES})
+include_directories(SYSTEM ${MPI_CXX_INCLUDE_PATH})
+set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${MPI_CXX_LINK_FLAGS}")
+add_definitions("-DPADDLE_WITH_MPI")
+find_program(
+  OMPI_INFO
+  NAMES ompi_info
+  HINTS ${MPI_CXX_LIBRARIES}/../bin)
+
+if(OMPI_INFO)
+  execute_process(COMMAND ${OMPI_INFO} OUTPUT_VARIABLE output_)
+  if(output_ MATCHES "smcuda")
+    # NOTE: some MPI libraries are built CUDA-aware.
+ add_definitions("-DPADDLE_WITH_MPI_AWARE") + endif() +endif() diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index f7de900d696..7f6a5e262b7 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -43,6 +43,13 @@ if(WITH_NCCL OR WITH_RCCL) endif() endif() +if(WITH_MPI) + cc_library( + processgroup_mpi + SRCS ProcessGroupMPI.cc MPITools.cc Common.cc + DEPS collective_helper device_context) +endif() + if(WITH_ASCEND_CL) cc_library( processgroup_hccl diff --git a/paddle/fluid/distributed/collective/MPITools.cc b/paddle/fluid/distributed/collective/MPITools.cc new file mode 100644 index 00000000000..042169728db --- /dev/null +++ b/paddle/fluid/distributed/collective/MPITools.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/collective/MPITools.h" +#include "paddle/fluid/distributed/collective/Common.h" +#include "paddle/fluid/distributed/collective/Types.h" + +namespace paddle { +namespace distributed { +namespace mpi { + +MPI_Op ToMPIType(ReduceOp reduction) { + static const std::map red_type = { + {ReduceOp::MIN, MPI_MIN}, + {ReduceOp::MAX, MPI_MAX}, + {ReduceOp::SUM, MPI_SUM}, + {ReduceOp::PRODUCT, MPI_PROD}, + }; + auto it = red_type.find(reduction); + PADDLE_ENFORCE_EQ(it != red_type.end(), + true, + platform::errors::InvalidArgument( + "Invalid mpi reduction. Must be MPI_MIN | MPI_MAX | " + "MPI_PROD | MPI_SUM.")); + return it->second; +} + +// NOTE: MPI dose not support CUDA aware now. +bool CheckMpiCudaAware() { return false; } + +void CheckValidInputs(const std::vector& tensors) { + PADDLE_ENFORCE_EQ( + tensors.size() == 1, + true, + platform::errors::InvalidArgument("the inputs size of MPI must be 1!")); + + PADDLE_ENFORCE_EQ(CheckTensorsInCudaPlace(tensors) && !CheckMpiCudaAware(), + false, + platform::errors::InvalidArgument( + "Found CUDA Tensor. But CUDA-aware MPI not support!")); +} + +} // namespace mpi +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/collective/MPITools.h b/paddle/fluid/distributed/collective/MPITools.h new file mode 100644 index 00000000000..34fbb3ca9a5 --- /dev/null +++ b/paddle/fluid/distributed/collective/MPITools.h @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include <iostream>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/platform/enforce.h"
+
+#include "paddle/fluid/distributed/collective/Types.h"
+
+#ifdef HOST
+#undef HOST
+#endif
+
+#include <mpi.h>
+
+namespace paddle {
+namespace distributed {
+namespace mpi {
+
+#define MPI_CHECK(cmd)                                                      \
+  do {                                                                      \
+    int r = cmd;                                                            \
+    if (r != MPI_SUCCESS) {                                                 \
+      LOG(FATAL) << "Failed, MPI error in " << __FILE__ << ":" << __LINE__  \
+                 << " with error code: " << std::to_string(r) << std::endl; \
+      exit(EXIT_FAILURE);                                                   \
+    }                                                                       \
+  } while (0)
+
+MPI_Op ToMPIType(ReduceOp reduction);
+
+bool CheckMpiCudaAware();
+
+void CheckValidInputs(const std::vector<phi::DenseTensor>& tensors);
+
+}  // namespace mpi
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/ProcessGroup.cc b/paddle/fluid/distributed/collective/ProcessGroup.cc
index 925d8e771cb..1db8d221aa6 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroup.cc
@@ -52,5 +52,13 @@ ProcessGroup::ProcessGroup(int rank,
   }
 }
 
+ProcessGroup::ProcessGroup(int rank, int size, int gid)
+    : rank_(rank), size_(size), gid_(gid) {
+  if (gid != IGNORE_ID) {
+    auto map = ProcessGroupMapFromGid::getInstance();
+    map->insert(gid_, this);
+  }
+}
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h
index 3db2464e59a..8a858aa81c9 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.h
+++ b/paddle/fluid/distributed/collective/ProcessGroup.h
@@ -82,6 +82,9 @@ class ProcessGroup {
                         int size,
                         const platform::Place& place,
                         int gid);
+
+  explicit ProcessGroup(int rank, int size, int gid);
+
   virtual ~ProcessGroup() {}
 
   int GetRank() const { return rank_; }
diff --git a/paddle/fluid/distributed/collective/ProcessGroupMPI.cc b/paddle/fluid/distributed/collective/ProcessGroupMPI.cc
new file mode 100644
index 00000000000..64be2e88c9f
--- /dev/null
+++ b/paddle/fluid/distributed/collective/ProcessGroupMPI.cc
@@ -0,0 +1,467 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/distributed/collective/ProcessGroupMPI.h" +#include +#include "paddle/fluid/distributed/collective/Common.h" +#include "paddle/fluid/platform/cuda_device_guard.h" + +constexpr int64_t kWaitBlockTImeout = 10; +namespace paddle { +namespace distributed { + +std::map mpiDatatype = { + {phi::DataType::INT8, MPI_CHAR}, + {phi::DataType::UINT8, MPI_UNSIGNED_CHAR}, + {phi::DataType::FLOAT32, MPI_FLOAT}, + {phi::DataType::FLOAT64, MPI_DOUBLE}, + {phi::DataType::INT32, MPI_INT}, + {phi::DataType::INT64, MPI_LONG}}; + +void ProcessGroupMPI::MPITask::FinishMPITaskError(std::exception_ptr eptr) { + Finish(eptr); +} + +void ProcessGroupMPI::MPITask::FinishMPITask() { Finish(); } + +ProcessGroupMPI::MPIAsyncTask::MPIAsyncTask( + MPI_Request request, const std::vector& inputs) + : ProcessGroup::Task(-1, inputs, CommType::UNKNOWN), request_(request) { + memset(&status_, 0, sizeof(status_)); +} + +ProcessGroupMPI::MPIAsyncTask::~MPIAsyncTask() { + if (request_ != MPI_REQUEST_NULL) { + std::cerr << " Task has not completed, try to destruct async mpi task, " + << "exit the program." << std::endl; + std::terminate(); + } +} + +bool ProcessGroupMPI::MPIAsyncTask::IsCompleted() { + if (request_ == MPI_REQUEST_NULL) { + return true; + } + + std::unique_lock lock(pg_global_mutex); + int flag = 0; + MPI_CHECK(MPI_Test(&request_, &flag, &status_)); + if (request_ != MPI_REQUEST_NULL) { + return false; + } + + if (status_.MPI_ERROR != MPI_SUCCESS) { + AppearException(); + } + + return true; +} + +bool ProcessGroupMPI::MPIAsyncTask::Wait(std::chrono::milliseconds timeout) { + if (request_ == MPI_REQUEST_NULL) { + return true; + } + + std::unique_lock lock(pg_global_mutex); + MPI_CHECK(MPI_Wait(&request_, &status_)); + + if (status_.MPI_ERROR != MPI_SUCCESS) { + AppearException(); + std::rethrow_exception(exception_); + return false; + } + + return true; +} + +void ProcessGroupMPI::MPIAsyncTask::AppearException() { + std::array buf; + int len = buf.size(); + MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len)); + exception_ = + std::make_exception_ptr(std::runtime_error(std::string(buf.data(), len))); +} + +void ProcessGroupMPI::MPIAsyncTask::SetOutputs( + std::vector& outputs) { + outputs_ = std::make_shared>(outputs); +} + +int ProcessGroupMPI::mpi_thread_support = 0; +std::mutex ProcessGroupMPI::pg_global_mutex; +std::once_flag ProcessGroupMPI::onceFlag; + +void ProcessGroupMPI::ExitMPI() { + std::unique_lock lock(pg_global_mutex); + MPI_CHECK(MPI_Finalize()); +} + +void ProcessGroupMPI::InitOneTimeMPI() { + std::call_once(onceFlag, []() { + MPI_CHECK(MPI_Init_thread( + nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpi_thread_support)); + PADDLE_ENFORCE_EQ( + mpi_thread_support < MPI_THREAD_SERIALIZED, + false, + platform::errors::InvalidArgument("MPI supports the number of threads " + "less than MPI_THREAD_SERIALIZED. 
")); + + std::atexit(ProcessGroupMPI::ExitMPI); + }); +} + +std::shared_ptr ProcessGroupMPI::CreateProcessGroupMPI( + const std::vector& ranks, int gid) { + InitOneTimeMPI(); + + MPI_Comm groupComm = MPI_COMM_WORLD; + int rank = -1; + int size = -1; + + { + std::lock_guard lock(pg_global_mutex); + + if (!ranks.empty()) { + MPI_Group worldGroup; + MPI_Group ranksGroup; + MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPI_CHECK( + MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup)); + + constexpr int maxRetries = 3; + bool create_success = false; + MPI_Barrier(MPI_COMM_WORLD); + for (auto i = 0; i < maxRetries; i++) { + if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)) { + create_success = true; + break; + } + } + MPI_CHECK(create_success); + MPI_CHECK(MPI_Group_free(&worldGroup)); + MPI_CHECK(MPI_Group_free(&ranksGroup)); + } + + if (groupComm != MPI_COMM_NULL) { + MPI_CHECK(MPI_Comm_rank(groupComm, &rank)); + MPI_CHECK(MPI_Comm_size(groupComm, &size)); + + PADDLE_ENFORCE_EQ( + rank < 0 || size < 0, + false, + platform::errors::InvalidArgument("get world_size or rank failed!")); + } + } + + if (groupComm == MPI_COMM_NULL) { + return std::shared_ptr(); + } + + VLOG(3) << "MPI Group Create Success! rank = " << rank << " size = " << size + << " group_id = " << gid; + + return std::make_shared(rank, size, groupComm, gid); +} + +ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pg_comm, int gid) + : ProcessGroup(rank, size, gid), stop_(false), pg_comm(pg_comm) { + PADDLE_ENFORCE_EQ( + pg_comm == MPI_COMM_NULL, + false, + platform::errors::InvalidArgument("Error! mpi comm is MPI_COMM_NULL!")); + + worker_thread = std::thread(&ProcessGroupMPI::workLoop, this); +} + +ProcessGroupMPI::~ProcessGroupMPI() { + std::unique_lock lock(pg_mutex); + queue_consume.wait(lock, [&] { return queue_.empty(); }); + stop_ = true; + lock.unlock(); + queue_produce.notify_all(); + + worker_thread.join(); +} + +void ProcessGroupMPI::workLoop() { + std::unique_lock lock(pg_mutex); + + while (!stop_) { + if (queue_.empty()) { + queue_produce.wait(lock); + continue; + } + + auto taskTuple = std::move(queue_.front()); + + queue_.pop_front(); + + auto& taskEntry = std::get<0>(taskTuple); + auto& task = std::get<1>(taskTuple); + + lock.unlock(); + queue_consume.notify_one(); + + try { + taskEntry->run_(taskEntry); + task->FinishMPITask(); + } catch (...) 
+      task->FinishMPITaskError(std::current_exception());
+    }
+
+    lock.lock();
+  }
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::Enqueue(
+    std::unique_ptr<TaskEntry> entry,
+    const std::vector<phi::DenseTensor>& inputs) {
+  auto task = std::make_shared<MPITask>(entry->dst_, inputs);
+  std::unique_lock<std::mutex> lock(pg_mutex);
+  queue_.push_back(std::make_tuple(std::move(entry), task));
+  lock.unlock();
+  queue_produce.notify_one();
+  return task;
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::Broadcast(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const BroadcastOptions& opts) {
+  mpi::CheckValidInputs(in_tensors);
+  const auto places = GetPlaceList(in_tensors);
+
+  std::function<void(std::unique_ptr<TaskEntry>&)> runFunc =
+      [opts, this](std::unique_ptr<TaskEntry>& entry) {
+        auto data = (entry->src_)[0];
+        std::unique_lock<std::mutex> lock(pg_global_mutex);
+        const auto root = opts.source_rank + opts.source_root;
+        MPI_CHECK(MPI_Bcast(data.data(),
+                            data.numel(),
+                            mpiDatatype.at(data.dtype()),
+                            root,
+                            pg_comm));
+      };
+  auto entry = std::make_unique<TaskEntry>(
+      &in_tensors, &out_tensors, std::move(runFunc));
+  return Enqueue(std::move(entry), in_tensors);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::AllReduce(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const AllreduceOptions& opts) {
+  mpi::CheckValidInputs(in_tensors);
+
+  std::function<void(std::unique_ptr<TaskEntry>&)> runFunc =
+      [opts, this](std::unique_ptr<TaskEntry>& entry) {
+        auto data = (entry->src_)[0];
+        std::unique_lock<std::mutex> lock(pg_global_mutex);
+        MPI_CHECK(MPI_Allreduce(MPI_IN_PLACE,
+                                data.data(),
+                                data.numel(),
+                                mpiDatatype.at(data.dtype()),
+                                mpi::ToMPIType(opts.reduce_op),
+                                pg_comm));
+      };
+  auto entry = std::make_unique<TaskEntry>(
+      &in_tensors, &out_tensors, std::move(runFunc));
+  return Enqueue(std::move(entry), in_tensors);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::Barrier(
+    const BarrierOptions& opts) {
+  std::function<void(std::unique_ptr<TaskEntry>&)> runFunc =
+      [this](std::unique_ptr<TaskEntry>& entry) {
+        std::unique_lock<std::mutex> lock(pg_global_mutex);
+        MPI_CHECK(MPI_Barrier(pg_comm));
+      };
+  auto entry =
+      std::make_unique<TaskEntry>(nullptr, nullptr, std::move(runFunc));
+  return Enqueue(std::move(entry), std::vector<phi::DenseTensor>{});
+}
+
+// NOTE: MPI_Isend/MPI_Irecv use gid_ as the message tag.
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::Send(
+    std::vector<phi::DenseTensor>& tensors, int dst_rank) {
+  mpi::CheckValidInputs(tensors);
+
+  auto& tensor = tensors[0];
+  MPI_Request request = MPI_REQUEST_NULL;
+
+  {
+    std::unique_lock<std::mutex> lock(pg_global_mutex);
+    MPI_CHECK(MPI_Isend(tensor.data(),
+                        tensor.numel(),
+                        mpiDatatype.at(tensor.dtype()),
+                        dst_rank,
+                        this->gid_,
+                        pg_comm,
+                        &request));
+  }
+
+  return std::make_shared<MPIAsyncTask>(request, tensors);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::Recv(
+    std::vector<phi::DenseTensor>& tensors, int src_rank) {
+  mpi::CheckValidInputs(tensors);
+
+  auto& tensor = tensors[0];
+  MPI_Request request = MPI_REQUEST_NULL;
+
+  {
+    std::unique_lock<std::mutex> lock(pg_global_mutex);
+    MPI_CHECK(MPI_Irecv(tensor.data(),
+                        tensor.numel(),
+                        mpiDatatype.at(tensor.dtype()),
+                        src_rank,
+                        this->gid_,
+                        pg_comm,
+                        &request));
+  }
+
+  return std::make_shared<MPIAsyncTask>(request, tensors);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::AllGather(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors) {
+  mpi::CheckValidInputs(in_tensors);
+
+  PADDLE_ENFORCE_EQ(out_tensors.size() == 1,
+                    true,
+                    platform::errors::InvalidArgument(
+                        "MPI only supports a single tensor op."));
+
+  std::function<void(std::unique_ptr<TaskEntry>&)> runFunc =
+      [this](std::unique_ptr<TaskEntry>& entry) {
+        auto data = (entry->src_)[0];
+        std::vector<phi::DenseTensor> dst = entry->dst_;
+
+        std::unique_lock<std::mutex> lock(pg_global_mutex);
+        MPI_CHECK(MPI_Allgather(data.data(),
+                                data.numel(),
+                                mpiDatatype.at(data.dtype()),
+                                dst[0].data(),
+                                data.numel(),
+                                mpiDatatype.at(data.dtype()),
+                                pg_comm));
+      };
+
+  auto entry = std::make_unique<TaskEntry>(
+      &in_tensors, &out_tensors, std::move(runFunc));
+
+  return Enqueue(std::move(entry), in_tensors);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::AllToAll(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors) {
+  mpi::CheckValidInputs(in_tensors);
+  mpi::CheckValidInputs(out_tensors);
+
+  PADDLE_ENFORCE_EQ(in_tensors[0].numel() == out_tensors[0].numel() &&
+                        in_tensors[0].dtype() == out_tensors[0].dtype(),
+                    true,
+                    platform::errors::InvalidArgument(
+                        "MPI AlltoAll: input and output are not equal in "
+                        "size or data type."));
+
+  std::function<void(std::unique_ptr<TaskEntry>&)> runFunc =
+      [this](std::unique_ptr<TaskEntry>& entry) {
+        auto srcdata = (entry->src_)[0];
+        auto dstdata = (entry->dst_)[0];
+        std::unique_lock<std::mutex> lock(pg_global_mutex);
+        MPI_CHECK(MPI_Alltoall(srcdata.data(),
+                               srcdata.numel() / size_,
+                               mpiDatatype.at(srcdata.dtype()),
+                               dstdata.data(),
+                               dstdata.numel() / size_,
+                               mpiDatatype.at(dstdata.dtype()),
+                               pg_comm));
+      };
+  auto entry = std::make_unique<TaskEntry>(
+      &in_tensors, &out_tensors, std::move(runFunc));
+
+  return Enqueue(std::move(entry), in_tensors);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::Reduce(
+    std::vector<phi::DenseTensor>& tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const ReduceOptions& opts) {
+  mpi::CheckValidInputs(tensors);
+
+  std::function<void(std::unique_ptr<TaskEntry>&)> runFunc =
+      [opts, this](std::unique_ptr<TaskEntry>& entry) {
+        auto data = (entry->src_)[0];
+        auto dataPtr = (entry->src_)[0].data();
+        void* sendbuf = (rank_ == opts.root_rank) ? MPI_IN_PLACE : dataPtr;
+        void* recvbuf = (rank_ == opts.root_rank) ? dataPtr : nullptr;
+
+        std::unique_lock<std::mutex> lock(pg_global_mutex);
+        MPI_CHECK(MPI_Reduce(sendbuf,
+                             recvbuf,
+                             data.numel(),
+                             mpiDatatype.at(data.dtype()),
+                             mpi::ToMPIType(opts.reduce_op),
+                             opts.root_rank,
+                             pg_comm));
+      };
+  auto entry =
+      std::make_unique<TaskEntry>(&tensors, &tensors, std::move(runFunc));
+  return Enqueue(std::move(entry), tensors);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupMPI::Scatter(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const ScatterOptions& opts) {
+  mpi::CheckValidInputs(in_tensors);
+
+  std::function<void(std::unique_ptr<TaskEntry>&)> runFunc =
+      [opts, this](std::unique_ptr<TaskEntry>& entry) {
+        auto data = (entry->dst_)[0];
+        void* sendbuf = nullptr;
+
+        if (rank_ == opts.root_rank) {
+          std::vector<phi::DenseTensor>& inputData = entry->src_;
+          sendbuf = inputData[0].data();
+        }
+
+        std::unique_lock<std::mutex> lock(pg_global_mutex);
+        MPI_CHECK(MPI_Scatter(sendbuf,
+                              data.numel(),
+                              mpiDatatype.at(data.dtype()),
+                              data.data(),
+                              data.numel(),
+                              mpiDatatype.at(data.dtype()),
+                              opts.root_rank,
+                              pg_comm));
+      };
+
+  if (rank_ == opts.root_rank) {
+    auto entry = std::make_unique<TaskEntry>(
+        &in_tensors, &out_tensors, std::move(runFunc));
+    return Enqueue(std::move(entry), in_tensors);
+  } else {
+    auto entry =
+        std::make_unique<TaskEntry>(nullptr, &out_tensors, std::move(runFunc));
+    return Enqueue(std::move(entry), in_tensors);
+  }
+}
+
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/ProcessGroupMPI.h b/paddle/fluid/distributed/collective/ProcessGroupMPI.h
new file mode 100644
index 00000000000..6a23b2fbb0d
--- /dev/null
+++ b/paddle/fluid/distributed/collective/ProcessGroupMPI.h
@@ -0,0 +1,211 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <chrono>
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <condition_variable>
+#include <deque>
+#include <exception>
+#include <mutex>
+
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
+#include "paddle/fluid/distributed/collective/Types.h"
+#include "paddle/fluid/platform/device_context.h"
+
+#if defined(PADDLE_WITH_MPI)
+#include "paddle/fluid/distributed/collective/MPITools.h"
+#endif
+
+constexpr const char* MPI_BACKEND_NAME = "MPI";
+
+namespace paddle {
+namespace distributed {
+
+struct TaskEntry {
+  explicit TaskEntry(std::vector<phi::DenseTensor>* src_ptr,
+                     std::vector<phi::DenseTensor>* dst_ptr,
+                     std::function<void(std::unique_ptr<TaskEntry>&)> run)
+      : dst_(dst_ptr ? *dst_ptr : std::vector<phi::DenseTensor>()),
+        run_(std::move(run)) {
+    if (src_ptr) {
+      src_ = *src_ptr;
+    }
+  }
+
+  TaskEntry(const TaskEntry&) = delete;
+  TaskEntry& operator=(const TaskEntry&) = delete;
+
+  std::vector<phi::DenseTensor> src_;
+  std::vector<phi::DenseTensor> dst_;
+
+  int* srcRank_ = nullptr;
+  std::function<void(std::unique_ptr<TaskEntry>&)> run_;
+};
+
+class ProcessGroupMPI : public ProcessGroup {
+ public:
+  class MPITask : public ProcessGroup::Task {
+   public:
+    explicit MPITask(std::vector<phi::DenseTensor> outputTensors,
+                     const std::vector<phi::DenseTensor>& inputTensors)
+        : ProcessGroup::Task(-1, inputTensors, CommType::UNKNOWN),
+          outputs_(std::move(outputTensors)) {}
+
+    void Synchronize() { Wait(); }
+
+    bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) {
+      std::unique_lock<std::mutex> lock(mutex_);
+      if (timeout == kWaitTimeout) {
+        // This waits without a timeout.
+        cv_.wait(lock, [&] { return is_completed_; });
+      } else {
+        // Waits for the user-provided timeout.
+        cv_.wait_for(lock, timeout, [&] { return is_completed_; });
+        PADDLE_ENFORCE_EQ(
+            is_completed_,
+            true,
+            platform::errors::InvalidArgument("MPI operation timeout!"));
")); + } + if (exception_) { + std::rethrow_exception(exception_); + } + return true; + } + + protected: + friend class ProcessGroupMPI; + + private: + // about mpi + void Finish(std::exception_ptr exception = nullptr) { + is_completed_ = true; + exception_ = exception; + cv_.notify_all(); + } + void FinishMPITask(); + void FinishMPITaskError(std::exception_ptr eptr); + + std::vector outputs_; + std::condition_variable cv_; + std::exception_ptr exception_; + }; + + public: + class MPIAsyncTask : public ProcessGroup::Task { + public: + MPIAsyncTask(MPI_Request request, + const std::vector& inputs); + + bool IsCompleted(); + + void Synchronize() {} + + bool Wait(std::chrono::milliseconds timeout = kWaitTimeout); + + void SetOutputs(std::vector& outputs); // NOLINT + + virtual ~MPIAsyncTask(); + + protected: + void AppearException(); + + private: + std::shared_ptr> outputs_; + MPI_Request request_; + MPI_Status status_; + std::exception_ptr exception_; + }; + + ProcessGroupMPI(int rank, int size, MPI_Comm pgComm, int gid); + + virtual ~ProcessGroupMPI(); + + const std::string GetBackendName() const override { + return std::string(MPI_BACKEND_NAME); + } + + std::shared_ptr AllReduce( + std::vector& in_tensors, + std::vector& out_tensors, + const AllreduceOptions& = AllreduceOptions()) override; + + std::shared_ptr Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& = BroadcastOptions()) override; + + std::shared_ptr Barrier( + const BarrierOptions& = BarrierOptions()) override; + + std::shared_ptr Send( + std::vector& tensors, int dst_rank) override; + + std::shared_ptr Recv( + std::vector& tensors, int src_rank) override; + + std::shared_ptr AllGather( + std::vector& in_tensors, + std::vector& out_tensors) override; + + std::shared_ptr AllToAll( + std::vector& in, + std::vector& out) override; + + std::shared_ptr Reduce( + std::vector& tensors, + std::vector& out_tensors, + const ReduceOptions& opts) override; + + std::shared_ptr Scatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ScatterOptions&) override; + + static std::shared_ptr CreateProcessGroupMPI( + const std::vector& ranks, int gid); + + protected: + void workLoop(); + + std::shared_ptr Enqueue( + std::unique_ptr entry, + const std::vector& inputs); + + private: + bool stop_{false}; + std::mutex pg_mutex; + std::thread worker_thread; + std::deque, std::shared_ptr>> + queue_; + std::condition_variable queue_produce; + std::condition_variable queue_consume; + + static void InitOneTimeMPI(); + static void ExitMPI(); + static std::once_flag onceFlag; + + static std::mutex pg_global_mutex; + static int mpi_thread_support; + + MPI_Comm pg_comm; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 7e5bf07ea31..f2370bf1593 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -151,6 +151,9 @@ if(WITH_PYTHON) if(WITH_GLOO) set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_gloo) endif() + if(WITH_MPI) + set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_mpi) + endif() if(WITH_ASCEND_CL) set(PYBIND_DEPS ${PYBIND_DEPS} processgroup_hccl) if(WITH_PSCORE) @@ -591,6 +594,10 @@ if(WITH_PYTHON) target_link_libraries(libpaddle ${ROCM_HIPRTC_LIB}) endif() + if(WITH_MPI) + target_link_libraries(libpaddle ${MPI_CXX_LIBRARIES}) + endif() + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) target_link_libraries(libpaddle ${os_dependency_modules}) 
     add_dependencies(libpaddle op_function_generator_cmd)
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 8a434f42811..86ba7b9d37d 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -36,6 +36,10 @@ limitations under the License. */
 #include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
 #endif
 
+#if defined(PADDLE_WITH_MPI)
+#include "paddle/fluid/distributed/collective/ProcessGroupMPI.h"
+#endif
+
 #if defined(PADDLE_WITH_ASCEND_CL)
 #include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
 #endif
@@ -623,6 +627,25 @@ void BindDistributed(py::module *m) {
 
 #endif
 
+#if defined(PADDLE_WITH_MPI)
+  py::class_<distributed::ProcessGroupMPI,
+             std::shared_ptr<distributed::ProcessGroupMPI>>(
+      *m, "ProcessGroupMPI", ProcessGroup)
+      .def_static(
+          "create",
+          [](const std::vector<int> &ranks,
+             int gid) -> std::shared_ptr<distributed::ProcessGroupMPI> {
+            return paddle::distributed::ProcessGroupMPI::CreateProcessGroupMPI(
+                ranks, gid);
+          })
+      .def("get_rank",
+           &distributed::ProcessGroup::GetRank,
+           py::call_guard<py::gil_scoped_release>())
+      .def("get_world_size",
+           &distributed::ProcessGroup::GetSize,
+           py::call_guard<py::gil_scoped_release>());
+#endif
+
 #if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
     (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
   py::class_
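
For reference, a minimal usage sketch of the Python binding registered in distributed_py.cc above, in the spirit of the patch's own test script process_group_mpi.py. This is illustrative and not part of the patch: the script name and group id value are made up, and the `paddle.fluid.core` import path is an assumption; `create`, `get_rank`, and `get_world_size` are the methods actually bound above.

    # usage_sketch.py -- illustrative only, not part of this patch.
    # Assumes Paddle was built with WITH_DISTRIBUTE=ON and WITH_MPI=ON, and
    # that the binding is reachable via paddle.fluid.core (an assumption).
    # Launch with: mpirun -np 2 python usage_sketch.py
    import paddle
    from paddle.fluid import core

    paddle.device.set_device("cpu")

    # An empty ranks list builds the group over MPI_COMM_WORLD; 0 is an
    # illustrative group id.
    pg = core.ProcessGroupMPI.create([], 0)
    print("rank", pg.get_rank(), "of", pg.get_world_size())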