// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "paddle/fluid/distributed/collective/ProcessGroup.h" #include "paddle/fluid/distributed/collective/ProcessGroupGloo.h" #include "paddle/fluid/platform/device_context.h" #ifdef PADDLE_WITH_GLOO #include "paddle/fluid/framework/fleet/gloo_wrapper.h" #endif #include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/NCCLTools.h" #include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" #include "paddle/fluid/platform/cuda_device_guard.h" #endif #if defined(PADDLE_WITH_ASCEND_CL) #include "paddle/fluid/distributed/collective/HCCLTools.h" #include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h" #endif #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ defined(PADDLE_WITH_ASCEND_CL)) #include "paddle/fluid/distributed/ps/service/heter_client.h" #endif #include "paddle/fluid/distributed/collective/Common.h" constexpr const char* HETER_BACKEND_NAME = "HETER_BACKEND"; namespace paddle { namespace distributed { using Place = paddle::platform::Place; class ProcessGroupHeter : public ProcessGroup { public: class HeterTask : public ProcessGroup::Task, public std::enable_shared_from_this { public: HeterTask(int rank, CommType CommType, const std::vector&); bool IsCompleted(); void SynchronizeStreams() {} bool Wait(std::chrono::milliseconds timeout = kWaitTimeout); void Synchronize() {} virtual ~HeterTask(); }; ProcessGroupHeter(const std::shared_ptr& store, int rank, int size, const platform::Place& place, int gid, int local_rank, int local_size, int gloo_rank, int gloo_size, bool with_switch, std::string switch_endpoints, int src_rank, int dst_rank); const std::string GetBackendName() const override { return std::string(HETER_BACKEND_NAME); } std::shared_ptr AllReduce( std::vector&, std::vector&, const AllreduceOptions& = AllreduceOptions()) override; std::shared_ptr Broadcast( std::vector&, std::vector&, const BroadcastOptions& = BroadcastOptions()) override; std::shared_ptr Send( std::vector& in_tensors, int peer) override; std::shared_ptr Recv( std::vector& out_tensors, int peer) override; protected: virtual std::shared_ptr CreateTask( int rank, CommType opType, const std::vector& inputs); private: std::shared_ptr store_; std::shared_ptr inner_pg_; std::shared_ptr inter_pg_; int local_rank_; int local_size_; int gloo_rank_; int gloo_size_; bool with_switch_; std::string switch_endpoint_; int src_rank_; int dst_rank_; static int send_count; static int recv_count; }; } // namespace distributed } // namespace paddle