// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"

// Default timeout passed to Task::Wait().
constexpr auto kWaitTimeout = std::chrono::milliseconds(0);

namespace paddle {
namespace distributed {

using Tensor = paddle::experimental::Tensor;

enum class CommType : std::uint8_t {
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_SPARSE = 2,  // TODO(shenliang03): to support sparse in allreduce
  REDUCE = 3,
  ALLGATHER = 4,
  GATHER = 5,
  SCATTER = 6,
  REDUCE_SCATTER = 7,
  ALLTOALL = 8,
  SEND = 9,
  RECV = 10,
  BARRIER = 11,
  UNKNOWN = 100,
};

struct ProcessGroupStrategy {
  int nranks_{1};
  int local_rank_{0};
  std::vector<std::string> trainer_endpoints_{};
  std::string current_endpoint_{""};
  int nrings_{1};
};

// Abstract base class for a collective-communication backend. Concrete
// backends override the collectives they support; the default stubs below
// throw for anything left unimplemented.
class ProcessGroup {
 public:
  // Handle to an in-flight collective operation.
  class Task {
   public:
    Task(int rank,
         const std::vector<Tensor>& inputTensors,
         CommType opType = CommType::UNKNOWN);
    virtual ~Task();

    virtual bool IsCompleted();
    virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
    virtual void Synchronize();

   protected:
    const int rank_;
    CommType comm_type_;
    std::mutex mutex_;
    bool is_completed_ = false;
  };

  explicit ProcessGroup(int rank, int size);
  virtual ~ProcessGroup() {}

  int GetRank() const { return rank_; }

  int GetSize() const { return size_; }

  virtual const std::string GetBackendName() const = 0;

  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<Tensor>& /* tensors */,
      const AllreduceOptions& = AllreduceOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support allreduce", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<Tensor>& /* tensors */,
      const BroadcastOptions& = BroadcastOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support broadcast", GetBackendName()));
  }

 protected:
  const int rank_;
  const int size_;
};

}  // namespace distributed
}  // namespace paddle
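
// -----------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header). A concrete
// backend is expected to derive from ProcessGroup, override the collectives
// it supports, and return a Task the caller can poll or block on. The name
// FakeProcessGroup and its body below are hypothetical placeholders, not a
// real Paddle backend:
//
//   class FakeProcessGroup : public ProcessGroup {
//    public:
//     FakeProcessGroup(int rank, int size) : ProcessGroup(rank, size) {}
//
//     const std::string GetBackendName() const override { return "FAKE"; }
//
//     std::shared_ptr<ProcessGroup::Task> AllReduce(
//         std::vector<Tensor>& tensors,
//         const AllreduceOptions& opts = AllreduceOptions()) override {
//       // A real backend would launch the reduction here and return a Task
//       // whose IsCompleted()/Wait() track the underlying work.
//       return std::make_shared<ProcessGroup::Task>(
//           GetRank(), tensors, CommType::ALLREDUCE);
//     }
//   };
//
//   // Caller side: collectives the backend does not override (e.g.
//   // Broadcast above) throw InvalidArgument; supported ones return a Task.
//   auto task = pg->AllReduce(tensors);
//   task->Wait();  // blocks, assuming a zero timeout means "no timeout"
// -----------------------------------------------------------------------------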