// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <condition_variable>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <vector>

#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/types.h"
#include "paddle/fluid/platform/device_context.h"

#if defined(PADDLE_WITH_MPI)
#include <mpi.h>

#include "paddle/fluid/distributed/collective/mpi_tools.h"
#endif

namespace paddle {
namespace distributed {

// A single queued collective operation: the source/destination tensors plus
// the function the worker thread invokes to execute it.
struct TaskEntry {
  explicit TaskEntry(std::vector<phi::DenseTensor>* src_ptr,
                     std::vector<phi::DenseTensor>* dst_ptr,
                     std::function<void(std::unique_ptr<TaskEntry>&)> run)
      : dst_(dst_ptr ? *dst_ptr : std::vector<phi::DenseTensor>()),
        run_(std::move(run)) {
    if (src_ptr) {
      src_ = *src_ptr;
    }
  }

  TaskEntry(const TaskEntry&) = delete;
  TaskEntry& operator=(const TaskEntry&) = delete;

  std::vector<phi::DenseTensor> src_;
  std::vector<phi::DenseTensor> dst_;

  int* srcRank_ = nullptr;
  std::function<void(std::unique_ptr<TaskEntry>&)> run_;
};

class ProcessGroupMPI : public ProcessGroup {
 public:
  // Task whose completion is signalled by the worker thread once the queued
  // MPI call has run.
  class MPITask : public ProcessGroup::Task {
   public:
    explicit MPITask(std::vector<phi::DenseTensor> outputTensors,
                     const std::vector<phi::DenseTensor>& inputTensors)
        : ProcessGroup::Task(-1, inputTensors, CommType::UNKNOWN),
          outputs_(std::move(outputTensors)) {}

    void Synchronize() { Wait(); }

    bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) {
      std::unique_lock<std::mutex> lock(mutex_);
      if (timeout == kWaitTimeout) {
        // This waits without a timeout.
        cv_.wait(lock, [&] { return is_completed_; });
      } else {
        // Waits for the user-provided timeout.
        cv_.wait_for(lock, timeout, [&] { return is_completed_; });
        PADDLE_ENFORCE_EQ(
            is_completed_,
            true,
            platform::errors::InvalidArgument("MPI operation timeout! "));
")); } if (exception_) { std::rethrow_exception(exception_); } return true; } protected: friend class ProcessGroupMPI; private: // about mpi void Finish(std::exception_ptr exception = nullptr) { is_completed_ = true; exception_ = exception; cv_.notify_all(); } void FinishMPITask(); void FinishMPITaskError(std::exception_ptr eptr); std::vector outputs_; std::condition_variable cv_; std::exception_ptr exception_; }; public: class MPIAsyncTask : public ProcessGroup::Task { public: MPIAsyncTask(MPI_Request request, const std::vector& inputs); bool IsCompleted(); void Synchronize() {} bool Wait(std::chrono::milliseconds timeout = kWaitTimeout); void SetOutputs(std::vector& outputs); // NOLINT virtual ~MPIAsyncTask(); protected: void AppearException(); private: std::shared_ptr> outputs_; MPI_Request request_; MPI_Status status_; std::exception_ptr exception_; }; ProcessGroupMPI(int rank, int size, MPI_Comm pgComm, int gid); virtual ~ProcessGroupMPI(); std::string GetBackendName() const override { return "MPI"; } std::shared_ptr AllReduce( std::vector& in_tensors, std::vector& out_tensors, const AllreduceOptions& = AllreduceOptions()) override; std::shared_ptr Broadcast( std::vector& in_tensors, std::vector& out_tensors, const BroadcastOptions& = BroadcastOptions()) override; std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) override; std::shared_ptr Send( std::vector& tensors, int dst_rank) override; std::shared_ptr Recv( std::vector& tensors, int src_rank) override; std::shared_ptr AllGather( std::vector& in_tensors, std::vector& out_tensors) override; std::shared_ptr AllToAll( std::vector& in, std::vector& out) override; std::shared_ptr Reduce( std::vector& tensors, std::vector& out_tensors, const ReduceOptions& opts) override; std::shared_ptr Scatter( std::vector& in_tensors, std::vector& out_tensors, const ScatterOptions&) override; static std::shared_ptr CreateProcessGroupMPI( const std::vector& ranks, int gid); protected: void workLoop(); std::shared_ptr Enqueue( std::unique_ptr entry, const std::vector& inputs); private: bool stop_{false}; std::mutex pg_mutex; std::thread worker_thread; std::deque, std::shared_ptr>> queue_; std::condition_variable queue_produce; std::condition_variable queue_consume; static void InitOneTimeMPI(); static void ExitMPI(); static std::once_flag onceFlag; static std::mutex pg_global_mutex; static int mpi_thread_support; MPI_Comm pg_comm; }; } // namespace distributed } // namespace paddle