// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "paddle/fluid/distributed/collective/custom_ccl_tools.h" #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/core/distributed/store/store.h" namespace paddle { namespace distributed { using Place = paddle::platform::Place; using CustomDeviceContext = paddle::platform::CustomDeviceContext; class ProcessGroupCustom : public ProcessGroupWithStream { public: class CustomTask : public ProcessGroup::Task, public std::enable_shared_from_this { public: CustomTask(const std::vector& places, int rank, CommType CommType, const std::vector& inputs); bool IsCompleted() override; void SynchronizeStreams(); bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override; void Synchronize() override; void UpdateWaitChain(const phi::DeviceContext& ctx) override; void SetOutputs(std::vector& outputs); // NOLINT virtual ~CustomTask(); std::vector control_events_; std::vector barrierTensors_; protected: std::vector places_; std::vector> cclComms_; std::shared_ptr> outputs_; private: const std::string device_type_; }; ProcessGroupCustom(const std::shared_ptr& store, const std::string& device_type, int rank, int size, int gid); static std::shared_ptr CreateProcessGroupCustom( const std::shared_ptr& store, const std::string& device_type, int rank, int size, int gid); std::string GetBackendName() const override { return "XCCL_" + device_type_; } std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) override; phi::DeviceContext* GetDeviceContext(const Place& place) const override; phi::ccl::CCLComm CustomCCLComm(const Place& place) const; // TODO(sunyilun): methods below will be removed later std::shared_ptr AllGather( std::vector& in_tensors, std::vector& out_tensors) override; std::shared_ptr AllGather( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, int64_t offset, int64_t numel, bool sync_op, bool use_calc_stream) override; std::shared_ptr AllGather( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, int64_t offset, int64_t numel, bool sync_op) override; std::shared_ptr AllReduce( std::vector& in_tensors, std::vector& out_tensors, const AllreduceOptions& = AllreduceOptions()) override; std::shared_ptr AllReduce( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, const AllreduceOptions& opts, bool sync_op, bool use_calc_stream) override; std::shared_ptr AllReduce( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, const AllreduceOptions& opts, bool sync_op) override; std::shared_ptr Broadcast( std::vector& in_tensors, std::vector& out_tensors, const BroadcastOptions& = BroadcastOptions()) override; std::shared_ptr Broadcast( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, const BroadcastOptions& opts, bool sync_op, bool use_calc_stream) override; std::shared_ptr Broadcast( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, const BroadcastOptions& opts, bool sync_op) override; protected: virtual std::shared_ptr CreateTask( std::vector places, int rank, CommType opType, const std::vector& inputs); std::shared_ptr store_; std::shared_ptr custom_comm_; std::mutex mutex_; std::unordered_map>> places_to_customcomm_; std::unordered_map> places_to_events_; std::unordered_map>> places_to_ctx_; std::set used_place_ids_; private: void BcastCustomId(std::vector& ccl_ids, // NOLINT int root, int server_fd); void BroadcastUniqueCustomID( std::vector& custom_ccl_ids); // NOLINT template std::shared_ptr Collective( std::vector& inputs, // NOLINT std::vector& outputs, // NOLINT Fn fn, CommType op_type, bool sync_op, bool use_calc_stream); void CreateCustomManagerCache(const std::string& places_key, const std::vector& places); const std::string device_type_; }; } // namespace distributed } // namespace paddle