// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "paddle/fluid/distributed/collective/ProcessGroup.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/stream/cuda_stream.h" #if defined(PADDLE_WITH_NCCL) #include "paddle/fluid/distributed/collective/NCCLTools.h" #include "paddle/fluid/platform/dynload/nccl.h" #endif constexpr const char* NCCL_BACKEND_NAME = "NCCL"; namespace paddle { namespace distributed { using Place = paddle::platform::Place; using CUDAStream = platform::stream::CUDAStream; using CUDADeviceContext = paddle::platform::CUDADeviceContext; class ProcessGroupNCCL : public ProcessGroup { public: class NCCLTask : public ProcessGroup::Task, public std::enable_shared_from_this { public: NCCLTask(const std::vector& places, int rank, CommType CommType, const std::vector& inputs); bool IsCompleted(); void SynchronizeStreams(); bool Wait(std::chrono::milliseconds timeout = kWaitTimeout); void Synchronize(); void SetOutputs(std::vector& outputs); // NOLINT virtual ~NCCLTask(); std::vector control_events_; protected: std::vector places_; std::vector> ncclComms_; std::shared_ptr> outputs_; private: }; ProcessGroupNCCL(const ProcessGroupStrategy& strategy, int rank, int size); const std::string GetBackendName() const override { return std::string(NCCL_BACKEND_NAME); } std::shared_ptr AllReduce( std::vector& tensors, const AllreduceOptions& = AllreduceOptions()) override; std::shared_ptr Broadcast( std::vector& tensors, const BroadcastOptions& = BroadcastOptions()) override; protected: virtual std::shared_ptr CreateTask( std::vector places, int rank, CommType opType, const std::vector& inputs); protected: ProcessGroupStrategy strategy_; std::shared_ptr nccl_comm_; std::mutex mutex_; std::unordered_map>> places_to_ncclcomm_; std::unordered_map> places_to_events_; std::unordered_map>> places_to_ctx_; private: void BcastNCCLId(std::vector& nccl_ids, int root, // NOLINT int server_fd); void BroadcastUniqueNCCLID(std::vector& nccl_ids); // NOLINT template std::shared_ptr Collective( std::vector& inputs, // NOLINT std::vector& outputs, // NOLINT Fn fn, CommType op_type); void CreateNCCLManagerCache(const std::string& places_key, const std::vector& places); }; } // namespace distributed } // namespace paddle