diff --git a/paddle/fluid/distributed/collective/custom_ccl_tools.cc b/paddle/fluid/distributed/collective/custom_ccl_tools.cc
index 9d95bcf3588d5d59bba059f8cbff014c67ab2407..ccafcf12a6c26f55ab093288cc0674ad142620f0 100644
--- a/paddle/fluid/distributed/collective/custom_ccl_tools.cc
+++ b/paddle/fluid/distributed/collective/custom_ccl_tools.cc
@@ -18,7 +18,7 @@
 namespace paddle {
 namespace distributed {
 
-phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction) {
+phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction) {
   static const std::map<ReduceOp, phi::ccl::CCLReduceOp> red_type = {
       {ReduceOp::MIN, phi::ccl::CCLReduceOp::MIN},
       {ReduceOp::MAX, phi::ccl::CCLReduceOp::MAX},
@@ -34,14 +34,5 @@ phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction) {
   return it->second;
 }
 
-std::string SerializeCustomCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) {
-  const uint8_t* bytes = ccl_id.data();
-  std::ostringstream oss;
-  for (size_t i = 0; i < ccl_id.size(); ++i) {
-    oss << std::hex << static_cast<int>(bytes[i]);
-  }
-  return oss.str();
-}
-
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/custom_ccl_tools.h b/paddle/fluid/distributed/collective/custom_ccl_tools.h
index 4388c6075485041f457835abfa6521956ff4a7fa..95557079a8252d11ca631e5cdb256155638a1181 100644
--- a/paddle/fluid/distributed/collective/custom_ccl_tools.h
+++ b/paddle/fluid/distributed/collective/custom_ccl_tools.h
@@ -34,170 +34,7 @@
 namespace paddle {
 namespace distributed {
 
-class CustomEventManager {
- public:
-  CustomEventManager() = default;
-
-  ~CustomEventManager() {
-    if (is_created_) {
-      event_->Destroy();
-    }
-  }
-
-  CustomEventManager(const CustomEventManager&) = delete;
-  CustomEventManager& operator=(const CustomEventManager&) = delete;
-
-  CustomEventManager(CustomEventManager&& other) {
-    std::swap(is_created_, other.is_created_);
-    std::swap(device_index_, other.device_index_);
-    std::swap(device_type_, other.device_type_);
-    std::swap(event_, other.event_);
-  }
-
-  CustomEventManager& operator=(CustomEventManager&& other) {
-    std::swap(is_created_, other.is_created_);
-    std::swap(device_index_, other.device_index_);
-    std::swap(device_type_, other.device_type_);
-    std::swap(event_, other.event_);
-    return *this;
-  }
-
-  bool IsCreated() const { return is_created_; }
-  int8_t DeviceId() const { return device_index_; }
-  std::string DeviceType() const { return device_type_; }
-  phi::event::event_t GetRawCustomEvent() const { return event_->raw_event(); }
-  phi::event::Event* GetCustomEvent() const { return event_.get(); }
-
-  void Record(const paddle::platform::CustomDeviceContext& ctx) {
-    auto place = ctx.GetPlace();
-    auto device_type = place.GetDeviceType();
-    auto device_index = place.GetDeviceId();
-    if (!is_created_) {
-      CreateEvent(place);
-    }
-    PADDLE_ENFORCE_EQ(device_index,
-                      device_index_,
-                      platform::errors::PreconditionNotMet(
-                          "CustomDeviceContext's device %d does not match"
-                          "Event's device %d",
-                          device_index,
-                          device_index_));
-    PADDLE_ENFORCE_EQ(device_type,
-                      device_type_,
-                      platform::errors::PreconditionNotMet(
-                          "CustomDeviceContext's device %d does not match"
-                          "Event's device type %d",
-                          device_type,
-                          device_type_));
-
-    phi::DeviceGuard guard(place);
-    phi::stream::Stream stream(place, ctx.stream());
-    event_->Record(&stream);
-  }
-
-  bool Query() const { return event_->Query(); }
-
-  void Block(const paddle::platform::CustomDeviceContext& ctx) const {
-    if (is_created_) {
-      auto place = ctx.GetPlace();
-      auto device_type = place.GetDeviceType();
-      auto device_index = place.GetDeviceId();
-      PADDLE_ENFORCE_EQ(device_index,
-                        device_index_,
-                        platform::errors::PreconditionNotMet(
-                            "CustomDeviceContext's device %d does not match"
-                            "Event's device %d",
-                            device_index,
-                            device_index_));
-      PADDLE_ENFORCE_EQ(device_type,
-                        device_type_,
-                        platform::errors::PreconditionNotMet(
-                            "CustomDeviceContext's device %d does not match"
-                            "Event's device type %d",
-                            device_type,
-                            device_type_));
-      phi::DeviceGuard guard(place);
-      phi::stream::Stream stream(place, ctx.stream());
-      stream.WaitEvent(event_.get());
-    }
-  }
-
- private:
-  bool is_created_{false};
-  std::shared_ptr<phi::event::Event> event_{nullptr};
-  int8_t device_index_{0};
-  std::string device_type_;
-
- private:
-  void CreateEvent(const platform::Place& place) {
-    device_index_ = place.GetDeviceId();
-    device_type_ = place.GetDeviceType();
-    event_.reset(new phi::event::Event);
-    event_->Init(place);
-    is_created_ = true;
-  }
-};
-
-class CustomCCLCommManager {
- public:
-  CustomCCLCommManager(const std::string& device_type,
-                       phi::ccl::CCLComm ccl_comm)
-      : device_type_(device_type), ccl_comm_(ccl_comm) {}
-
-  CustomCCLCommManager() : CustomCCLCommManager("", nullptr) {}
-
-  ~CustomCCLCommManager() noexcept {
-    std::unique_lock<std::mutex> lock(mutex_);
-    if (phi::DeviceManager::HasDeviceType(device_type_) && ccl_comm_) {
-      phi::DeviceManager::CCLDestroyComm(device_type_, ccl_comm_);
-    }
-  }
-
-  static std::shared_ptr<CustomCCLCommManager> Create(
-      const std::string& device_type,
-      int num_ranks,
-      int rank,
-      phi::ccl::CCLRootId* comm_id,
-      phi::ccl::CCLComm* ccl_comm) {
-    auto custom_ccl_manager = std::make_shared<CustomCCLCommManager>();
-    phi::DeviceManager::CCLCommInitRank(
-        device_type, num_ranks, comm_id, rank, ccl_comm);
-    custom_ccl_manager->device_type_ = device_type;
-    custom_ccl_manager->ccl_id_ = comm_id;
-    custom_ccl_manager->rank_ = rank;
-    custom_ccl_manager->ccl_comm_ = *ccl_comm;
-    return custom_ccl_manager;
-  }
-
-  phi::ccl::CCLRootId* GetCustomCCLId() const {
-    std::unique_lock<std::mutex> lock(mutex_);
-    return ccl_id_;
-  }
-
-  phi::ccl::CCLComm GetCustomCCLComm() const {
-    std::unique_lock<std::mutex> lock(mutex_);
-    return ccl_comm_;
-  }
-
-  CustomCCLCommManager(const CustomCCLCommManager&) = delete;
-  CustomCCLCommManager& operator=(const CustomCCLCommManager&) = delete;
-  CustomCCLCommManager& operator=(CustomCCLCommManager&& other) = delete;
-
-  CustomCCLCommManager(CustomCCLCommManager&& other) {
-    std::unique_lock<std::mutex> lock(other.mutex_);
-    std::swap(ccl_comm_, other.ccl_comm_);
-  }
-
- protected:
-  std::string device_type_;
-  phi::ccl::CCLComm ccl_comm_;
-  phi::ccl::CCLRootId* ccl_id_;
-  int rank_;
-  mutable std::mutex mutex_;
-};
-
-phi::ccl::CCLReduceOp ToCustomCCLRedType(ReduceOp reduction);
-std::string SerializeCustomCCLUniqueId(const phi::ccl::CCLRootId& ccl_id);
+phi::ccl::CCLReduceOp ToXCCLRedType(ReduceOp reduction);
 
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc
index 9ad124341c374e4241319ac7d76c09ba932bb75a..6d5c30da3133b577938fdb3a5822a2bea4ff87f0 100644
--- a/paddle/fluid/distributed/collective/process_group_custom.cc
+++ b/paddle/fluid/distributed/collective/process_group_custom.cc
@@ -17,99 +17,62 @@
 #include "paddle/fluid/distributed/collective/common.h"
 #include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
 #include "paddle/fluid/distributed/collective/utils.h"
-#include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/place.h"
 #include
"paddle/phi/api/lib/utils/allocator.h" -#include "paddle/phi/common/place.h" #include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/flags.h" +#include "paddle/phi/core/utils/data_type.h" -DECLARE_bool(xccl_blocking_wait); +#include "paddle/phi/core/distributed/comm_context_manager.h" constexpr int64_t kWaitBlockTImeout = 10; +DECLARE_bool(use_stream_safe_cuda_allocator); + namespace paddle { namespace distributed { -void SyncDefaultStream( - const std::vector& places, - std::vector& cclEvents, // NOLINT - std::vector>& dev_ctx) { // NOLINT - for (size_t i = 0; i < places.size(); ++i) { - auto* default_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(places[i])); - cclEvents[i].Record(*default_ctx); - cclEvents[i].Block(*dev_ctx[i]); - } -} - -std::shared_ptr ProcessGroupCustom::CreateTask( - std::vector places, - int rank, - CommType comm_type, - const std::vector& inputs) { - return std::make_shared( - places, rank, comm_type, inputs); -} - -ProcessGroupCustom::CustomTask::CustomTask( - const std::vector& places, - int rank, - CommType CommType, - const std::vector& inputs) - : Task(rank, inputs, CommType), places_(places) { - control_events_.resize(places.size()); - cclComms_.resize(places.size()); +ProcessGroupCustom::XCCLTask::XCCLTask(const Place& place, + int rank, + CommType comm_type, + bool sync_op, + bool use_calc_stream) + : TaskStream(rank, comm_type, sync_op, use_calc_stream), + task_place_(place) { + comm_event_.Init(place); } -ProcessGroupCustom::CustomTask::~CustomTask() {} +ProcessGroupCustom::XCCLTask::~XCCLTask() = default; -void ProcessGroupCustom::CustomTask::SetOutputs( - std::vector& outputs) { // NOLINT - outputs_ = std::make_shared>(outputs); -} +bool ProcessGroupCustom::XCCLTask::IsCompleted() { return comm_event_.Query(); } -void ProcessGroupCustom::CustomTask::SynchronizeStreams() { - for (size_t i = 0; i < places_.size(); ++i) { - auto* default_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(places_[i])); - phi::DeviceGuard guard(default_ctx->GetPlace()); - control_events_[i].Block(*default_ctx); - } +void ProcessGroupCustom::XCCLTask::UpdateWaitChain( + const phi::DeviceContext& ctx) { + comm_event_.Record( + reinterpret_cast(ctx).GetStream().get()); } -bool ProcessGroupCustom::CustomTask::IsCompleted() { - for (size_t i = 0; i < places_.size(); ++i) { - if (!control_events_[i].Query()) { - return false; - } +bool ProcessGroupCustom::XCCLTask::Wait(std::chrono::milliseconds timeout) { + // Warning here when use calc stream but also invoke waiting explicitly. 
+ if (UseCalcStream()) { + VLOG(3) << "Warning: The communication is on calc stream, wait here is " + "useless."; + return true; } - return true; -} + const auto* calc_ctx = reinterpret_cast( + platform::DeviceContextPool::Instance().Get(task_place_)); + calc_ctx->GetStream()->WaitEvent(&comm_event_); -bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) { - SynchronizeStreams(); - while (!IsCompleted()) { - std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout)); + if (IsBlockCPUInWait()) { + // If we use the work to do barrier, we should block cpu + phi::DeviceManager::SynchronizeDevice(task_place_); } return true; } // Same as Wait -void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); } - -void ProcessGroupCustom::CustomTask::UpdateWaitChain( - const phi::DeviceContext& ctx) { - PADDLE_ENFORCE_NE( - std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()), - places_.cend(), - phi::errors::NotFound("Cannot find the device context in this task.")); - auto index = std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()) - - places_.cbegin(); - control_events_[index].Record( - reinterpret_cast(ctx)); -} +void ProcessGroupCustom::XCCLTask::Synchronize() { Wait(kWaitTimeout); } ProcessGroupCustom::ProcessGroupCustom( const std::shared_ptr& store, @@ -121,147 +84,45 @@ ProcessGroupCustom::ProcessGroupCustom( store_(store), device_type_(device_type) {} -void ProcessGroupCustom::BroadcastUniqueCustomID( - std::vector& ccl_ids) { // NOLINT - if (rank_ == 0) { - for (size_t i = 0; i < ccl_ids.size(); i++) { - auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" + - std::to_string(i); - store_->set(key, ccl_ids[i]); - } - } else { - for (size_t i = 0; i < ccl_ids.size(); i++) { - auto key = "ProcessGroupCustom/ccl_ids/" + std::to_string(gid_) + "/" + - std::to_string(i); - ccl_ids[i] = store_->get(key); - } - } +void ProcessGroupCustom::GroupStart(const std::string& dev_type) { + phi::DeviceManager::CCLGroupStart(dev_type); } -// create CustomCCLManager cache for places_key -void ProcessGroupCustom::CreateCustomManagerCache( - const std::string& places_key, const std::vector& places) { - PADDLE_ENFORCE_EQ( - places_key.empty(), - false, - platform::errors::PreconditionNotMet( - "Not able to create/get the CustomCCL Communicator since " - "the NPU place are not known")); - const std::string device_type = places.back().GetDeviceType(); - - std::vector> ccl_comms; - ccl_comms.resize(places.size()); - - // using vector just for broadcast - std::vector ccl_ids; - ccl_ids.resize(1); - auto& ccl_id = ccl_ids.front(); - - if (rank_ == 0) { - phi::DeviceManager::CCLGetUniqueId(device_type, &ccl_id); - } - BroadcastUniqueCustomID(ccl_ids); - - VLOG(3) << "init custom ccl rank: " << rank_ << ", nranks: " << size_ - << ", place: " << places_key - << ", custom ccl uniqueid: " << SerializeCustomCCLUniqueId(ccl_id); - - std::vector> dev_ctx; - dev_ctx.resize(places.size()); - - for (size_t i = 0; i < places.size(); ++i) { - phi::DeviceGuard guard(places[i]); - ccl_comms[i] = CustomCCLCommManager::Create( - device_type, GetSize(), GetRank(), &ccl_id, new phi::ccl::CCLComm); - dev_ctx[i] = std::make_unique(places[i]); - dev_ctx[i]->SetAllocator( - &(phi::DeviceContextPool::Instance().Get(places[i])->GetAllocator())); - dev_ctx[i]->SetHostAllocator(&( - phi::DeviceContextPool::Instance().Get(places[i])->GetHostAllocator())); - dev_ctx[i]->SetZeroAllocator(&( - phi::DeviceContextPool::Instance().Get(places[i])->GetZeroAllocator())); 
- dev_ctx[i]->SetHostZeroAllocator(&(phi::DeviceContextPool::Instance() - .Get(places[i]) - ->GetHostZeroAllocator())); - } - - std::vector events; - events.resize(places.size()); - - // These caches will be useful to process sync/wait/communicate - places_to_events_.emplace(places_key, std::move(events)); - places_to_customcomm_.emplace(places_key, std::move(ccl_comms)); - places_to_ctx_.emplace(places_key, std::move(dev_ctx)); +void ProcessGroupCustom::GroupEnd(const std::string& dev_type) { + phi::DeviceManager::CCLGroupEnd(dev_type); } -template -std::shared_ptr ProcessGroupCustom::Collective( - std::vector& inputs, - std::vector& outputs, - Fn fn, - CommType op_type, - bool sync_op UNUSED, - bool use_calc_stream) { - const auto places = GetPlaceList(inputs); - const auto key = GetKeyFromPlaces(places); - - { - std::lock_guard lock(mutex_); - if (places_to_customcomm_.find(key) == places_to_customcomm_.end()) { - CreateCustomManagerCache(key, places); - } - } - - auto& ccl_comms = places_to_customcomm_[key]; - if (!use_calc_stream) { - SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); - } - auto task = CreateTask(places, rank_, op_type, inputs); - task->SetOutputs(outputs); - - for (size_t i = 0; i < inputs.size(); ++i) { - phi::DeviceGuard guard(places[i]); - const auto& ccl_stream = - use_calc_stream ? reinterpret_cast( - phi::DeviceContextPool::Instance().Get(places[i])) - ->stream() - : places_to_ctx_[key][i]->stream(); - phi::stream::Stream stream(places[i], ccl_stream); - fn(inputs[i], outputs[i], ccl_comms[i]->GetCustomCCLComm(), stream); - } - - if (!use_calc_stream) { - for (size_t i = 0; i < inputs.size(); ++i) { - phi::DeviceGuard guard(places[i]); - task->control_events_[i].Record(*places_to_ctx_[key][i]); - } - } - return task; +phi::DeviceContext* ProcessGroupCustom::GetDeviceContext( + const Place& place) const { + return GetDeviceContext(place, /*use_calc_stream*/ false); } -void* XcclGetPointerByOffset(void* raw_pointer, - size_t offset, - phi::DataType type) { - if (type == phi::DataType::FLOAT32) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::FLOAT64) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::INT32) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::INT64) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); - } else if (type == phi::DataType::FLOAT16) { - return reinterpret_cast(reinterpret_cast(raw_pointer) + - offset); +phi::DeviceContext* ProcessGroupCustom::GetDeviceContext( + const Place& place, bool use_calc_stream) const { + const std::string& key = GetKeyFromPlace(place); + if (use_calc_stream) { + const auto& iter = place_to_calc_ctx_.find(key); + return iter->second; } else { - PADDLE_THROW(platform::errors::Unimplemented( - "This datatype in xccl is not supported.")); + const auto& iter = place_to_comm_ctx_.find(key); + PADDLE_ENFORCE_NE( + iter, + place_to_comm_ctx_.end(), + phi::errors::NotFound( + "Cannot find the device context in this process group.")); + return iter->second.get(); } - return nullptr; +} + +phi::ccl::CCLComm ProcessGroupCustom::XCCLComm(const Place& place) const { + const std::string& key = GetKeyFromPlace(place); + const auto& iter = place_to_comm_ctx_.find(key); + PADDLE_ENFORCE_NE( + iter, + place_to_comm_ctx_.end(), + phi::errors::NotFound( + "Cannot find the XCCL communicator in this process group.")); + return 
iter->second->xccl_comm(); } std::shared_ptr ProcessGroupCustom::AllGather( @@ -269,212 +130,110 @@ std::shared_ptr ProcessGroupCustom::AllGather( const phi::DenseTensor& in_tensor, int64_t offset, int64_t numel, - bool sync_op, // for compatibility, no use now + bool sync_op, bool use_calc_stream) { // numel > 0 indicates the tensor need to be sliced const phi::DenseTensor& in_tensor_maybe_partial = - numel > 0 - ? paddle::distributed::GetPartialTensor(in_tensor, offset, numel) - : in_tensor; - phi::distributed::CommStaticCheck::GatherLikeShape( - *out_tensor, - in_tensor_maybe_partial, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_, - phi::AllocationType::CUSTOM); - std::vector in_wrapper{in_tensor_maybe_partial}; - std::vector out_wrapper{*out_tensor}; - - return Collective( - in_wrapper, - out_wrapper, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - return phi::DeviceManager::CCLAllGather( - device_type_, - input.data(), - output.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - comm, - stream); + numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor; + return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream); }, + in_tensor_maybe_partial, CommType::ALLGATHER, sync_op, use_calc_stream); } -// TODO(sunyilun): methods below will be removed later -std::shared_ptr ProcessGroupCustom::AllGather( - std::vector& in_tensors, - std::vector& out_tensors) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(in_tensors, device_type_), - true, - platform::errors::InvalidArgument( - "All inputs should be in CustomPlace(%s).", device_type_)); - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(out_tensors, device_type_), - true, - platform::errors::InvalidArgument( - "All outputs should be in CustomPlace(%s).", device_type_)); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - return phi::DeviceManager::CCLAllGather( - device_type_, - input.data(), - output.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - comm, - stream); - }, - CommType::ALLGATHER, - false, - false); -} - std::shared_ptr ProcessGroupCustom::AllReduce( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, const AllreduceOptions& opts, - bool sync_op, // for compatibility, no use now + bool sync_op, bool use_calc_stream) { - std::vector in_wrapper{in_tensor}; - std::vector out_wrapper{*out_tensor}; - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(in_wrapper, device_type_), - true, - platform::errors::InvalidArgument( - "All inputs should be in CustomPlace(%s).", device_type_)); - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(out_wrapper, device_type_), - true, - platform::errors::InvalidArgument( - "All outputs should be in CustomPlace(%s).", device_type_)); - return Collective( - in_wrapper, - out_wrapper, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - return phi::DeviceManager::CCLAllReduce( - device_type_, - input.data(), - output.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - ToCustomCCLRedType(opts.reduce_op), - comm, + return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + 
comm_context->AllReduce( + out_tensor, + in_tensor, + paddle::distributed::ToXCCLRedType(opts.reduce_op), stream); }, + in_tensor, CommType::ALLREDUCE, sync_op, use_calc_stream); } -std::shared_ptr ProcessGroupCustom::AllReduce( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const AllreduceOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(in_tensors, device_type_), - true, - platform::errors::InvalidArgument( - "All inputs should be in CustomPlace(%s).", device_type_)); - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(out_tensors, device_type_), - true, - platform::errors::InvalidArgument( - "All outputs should be in CustomPlace(%s).", device_type_)); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - return phi::DeviceManager::CCLAllReduce( - device_type_, - input.data(), - output.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - ToCustomCCLRedType(opts.reduce_op), - comm, - stream); - }, - CommType::ALLREDUCE, - false, - false); -} - -std::shared_ptr ProcessGroupCustom::Broadcast( +std::shared_ptr ProcessGroupCustom::AllToAll( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, - const BroadcastOptions& opts, - bool sync_op, // for compatibility, no use now + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, bool use_calc_stream) { - std::vector in_wrapper{in_tensor}; - std::vector out_wrapper{*out_tensor}; - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(in_wrapper, device_type_), - true, - platform::errors::InvalidArgument( - "All inputs should be in CustomPlace(%s).", device_type_)); - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(out_wrapper, device_type_), - true, - platform::errors::InvalidArgument( - "All outputs should be in CustomPlace(%s).", device_type_)); - return Collective( - in_wrapper, - out_wrapper, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - int root = opts.source_rank * in_wrapper.size() + opts.source_root; - if (rank_ == root) { - return phi::DeviceManager::CCLBroadcast( - device_type_, - input.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - root, - comm, - stream); - } else { - return phi::DeviceManager::CCLBroadcast( - device_type_, - output.data(), - output.numel(), - phi::ccl::ToCCLDataType(output.dtype()), - root, - comm, - stream); + const phi::DDim& out_dim = out_tensor->dims(); + const phi::DDim& in_dim = in_tensor.dims(); + CheckSizeOnEachRank(out_dim, out_size_each_rank, size_); + CheckSizeOnEachRank(in_dim, in_size_each_rank, size_); + + // NOTE: Since `all_to_all` needs other processes' participation, it cannot + // simply be covered by static checks. Factors are set to 0 here to skip the + // shape check. Its shape check will be done by dynamic checks with + // FLAGS_enable_xccl_dynamic_check. 
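+  // Worked example with assumed sizes: for size_ = 2 and
+  // in_size_each_rank = {2, 3}, rows [0, 2) of in_tensor go to rank 0 and
+  // rows [2, 5) go to rank 1; the per-rank buffers, counts, and dtypes are
+  // assembled below from exactly these variable-length row ranges.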
+ return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + + int64_t in_row_size = in_tensor.numel() / in_dim[0], + out_row_size = out_tensor->numel() / out_dim[0]; + int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; + phi::DenseTensor input_partial, output_partial; + + std::vector send_buf, recv_buf; + std::vector send_count, recv_count; + std::vector send_dtype, recv_dtype; + for (auto i = 0; i < size_; i++) { + in_numel = in_size_each_rank[i] * in_row_size; + input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); + out_numel = out_size_each_rank[i] * out_row_size; + output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel); + in_offset += in_numel; + out_offset += out_numel; + send_buf.push_back(input_partial.data()); + recv_buf.push_back(output_partial.data()); + send_count.push_back(in_numel); + recv_count.push_back(out_numel); + send_dtype.push_back(phi::ccl::ToCCLDataType(input_partial.dtype())); + recv_dtype.push_back(phi::ccl::ToCCLDataType(output_partial.dtype())); } + + phi::DeviceManager::CCLAllToAll( + device_type_, + const_cast(send_buf.data()), + send_count.data(), + send_dtype.data(), + recv_buf.data(), + recv_count.data(), + recv_dtype.data(), + rank_, + size_, + comm_context->GetXcclComm(), + stream); }, - CommType::BROADCAST, + in_tensor, + CommType::ALLTOALL, sync_op, use_calc_stream); } std::shared_ptr ProcessGroupCustom::Barrier( const BarrierOptions& opts) { - // Only support single card single process PADDLE_ENFORCE_GE(opts.device_id, 0, - platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "The barrier device id must greater or equal than 0.")); platform::CustomPlace place(device_type_, opts.device_id); auto allocator = std::unique_ptr( @@ -482,111 +241,176 @@ std::shared_ptr ProcessGroupCustom::Barrier( phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1}); phi::DenseTensor barrier_tensor{allocator.get(), meta}; - auto task = ProcessGroupCustom::AllReduce(&barrier_tensor, - barrier_tensor, - {}, - /*sync_op*/ true, - false); - auto xccl_task = dynamic_cast(task.get()); - xccl_task->barrierTensors_ = {barrier_tensor}; + auto task = AllReduce(&barrier_tensor, + barrier_tensor, + {}, + /*sync_op*/ true, + /*use_calc_stream*/ false); + auto xccl_task = dynamic_cast(task.get()); + xccl_task->SetBlockCPUInWait(); return task; } -phi::DeviceContext* ProcessGroupCustom::GetDeviceContext( - const Place& place) const { - const std::string key = GetKeyFromPlace(place); - const auto& iter = places_to_ctx_.find(key); - PADDLE_ENFORCE_NE( - iter, - places_to_ctx_.end(), - platform::errors::NotFound( - "Cannot find the device context in this process group.")); - return iter->second[0].get(); -} - -phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const { - std::vector places = {place}; - const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places)); - PADDLE_ENFORCE_NE(iter, - places_to_customcomm_.end(), - platform::errors::InvalidArgument( - "Cannot find nccl comm in process group.")); - return iter->second[0]->GetCustomCCLComm(); -} - std::shared_ptr ProcessGroupCustom::Broadcast( - std::vector& in_tensors, // NOLINT - std::vector& out_tensors, // NOLINT - const BroadcastOptions& opts) { - PADDLE_ENFORCE_EQ( - CheckTensorsInCustomPlace(in_tensors, device_type_), - true, - platform::errors::InvalidArgument( - "All inputs should be in CustomPlace(%s).", device_type_)); - PADDLE_ENFORCE_EQ( - 
CheckTensorsInCustomPlace(out_tensors, device_type_), - true, - platform::errors::InvalidArgument( - "All outputs should be in CustomPlace(%s).", device_type_)); - return Collective( - in_tensors, - out_tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - int root = opts.source_rank * in_tensors.size() + opts.source_root; - if (rank_ == root) { - return phi::DeviceManager::CCLBroadcast( - device_type_, - input.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - root, - comm, - stream); - } else { - return phi::DeviceManager::CCLBroadcast( - device_type_, - output.data(), - output.numel(), - phi::ccl::ToCCLDataType(output.dtype()), - root, - comm, - stream); - } + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) { + return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + int root = opts.source_rank + opts.source_root; + auto comm_context = this->GetCommContext(); + comm_context->Broadcast(out_tensor, in_tensor, root, stream); }, + in_tensor, CommType::BROADCAST, - false, - false); + sync_op, + use_calc_stream); } -void CheckTensorsInDifferentCustomDevices( - const std::vector& tensors, const size_t num_devices) { - PADDLE_ENFORCE_EQ( - tensors.size() == 0, - false, - phi::errors::InvalidArgument("Tensor list must be nonempty.")); - PADDLE_ENFORCE_LE( - tensors.size(), - num_devices, - phi::errors::InvalidArgument("Tensor list mustn't be larger than the " - "number of available CustomDevice.")); +std::shared_ptr ProcessGroupCustom::Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + comm_context->Reduce(out_tensor, + in_tensor, + paddle::distributed::ToXCCLRedType(opts.reduce_op), + opts.root_rank, + stream); + }, + in_tensor, + CommType::REDUCE, + sync_op, + use_calc_stream); +} - std::set used_devices; +std::shared_ptr ProcessGroupCustom::ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + comm_context->ReduceScatter( + out_tensor, + in_tensor, + paddle::distributed::ToXCCLRedType(opts.reduce_op), + stream); + }, + in_tensor, + CommType::REDUCE_SCATTER, + sync_op, + use_calc_stream); +} - for (const auto& t : tensors) { - PADDLE_ENFORCE_EQ(platform::is_custom_place(t.place()), - true, - phi::errors::InvalidArgument( - "Tensors must be CustomDevice and dense tensor.")); +std::shared_ptr ProcessGroupCustom::Scatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + phi::distributed::CommStaticCheck::ScatterLikeShape( + *out_tensor, + in_tensor, + /*dst_rank*/ opts.root_rank, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CUSTOM); + return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); - const auto inserted = used_devices.insert(t.place()).second; - PADDLE_ENFORCE_EQ(inserted, - true, - phi::errors::InvalidArgument( - "Tensors must be on distinct custom devices.")); + int64_t numel = in_tensor.numel() / size_; 
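+          // Root-side protocol: slice i of the input (numel elements) is
+          // sent to rank i, except the root's own slice, which is copied
+          // device-to-device into out_tensor; every non-root rank posts the
+          // matching Recv in the else branch.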
+ if (rank_ == opts.root_rank) { + int64_t offset = 0; + phi::DenseTensor partial_tensor; + for (auto i = 0; i < size_; i++) { + partial_tensor = GetPartialTensor(in_tensor, offset, numel); + if (i != rank_) { + comm_context->Send(partial_tensor, numel, i, stream); + } else { + phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace()) + ->MemoryCopyD2D(out_tensor->data(), + partial_tensor.data(), + numel * phi::SizeOf(partial_tensor.dtype()), + &stream); + } + offset += numel; + } + } else { + comm_context->Recv(out_tensor, numel, opts.root_rank, stream); + } + }, + in_tensor, + CommType::SCATTER, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupCustom::Gather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) { + std::vector partial_tensors; + if (rank_ == opts.root_rank) { + partial_tensors.reserve(size_); + size_t offset = 0; + size_t numel = out_tensor->numel() / size_; + for (auto i = 0; i < size_; i++) { + partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel)); + offset += numel; + } } + return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream); +} + +std::shared_ptr ProcessGroupCustom::Gather( + std::vector* gather_tensors_ptr, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) { + auto& gather_tensors = *gather_tensors_ptr; + PADDLE_ENFORCE_GT(size_, + opts.root_rank, + phi::errors::InvalidArgument( + "root world size [%d] is less than root rank [%d]", + size_, + opts.root_rank)); + auto gather_func = [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + // root receive from all devices + if (rank_ == opts.root_rank) { + for (auto i = 0; i < size_; i++) { + auto& gather_tensor = gather_tensors[i]; + if (i != rank_) { + comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream); + } else { + phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace()) + ->MemoryCopyD2D( + gather_tensor.data(), + in_tensor.data(), + in_tensor.numel() * phi::SizeOf(in_tensor.dtype()), + &stream); + } + } + } else { + // send to root + comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream); + } + }; + return RunFnInXCCLEnv( + gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream); } std::shared_ptr ProcessGroupCustom::Recv( @@ -602,53 +426,18 @@ std::shared_ptr ProcessGroupCustom::Recv( partial_tensor = GetPartialTensor(*tensor, offset, numel); tensor = &partial_tensor; } - phi::distributed::CommStaticCheck::CheckShape( - *tensor, rank_, size_, phi::AllocationType::CUSTOM); - std::vector in_wrapper{*tensor}; - std::vector out_wrapper{*tensor}; - return Collective( - in_wrapper, - out_wrapper, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - phi::DeviceManager::CCLRecv(device_type_, - output.data(), - output.numel(), - phi::ccl::ToCCLDataType(output.dtype()), - src_rank, - comm, - stream); + + return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + comm_context->Recv(tensor, tensor->numel(), src_rank, stream); }, + *tensor, CommType::RECV, sync_op, use_calc_stream); } -std::shared_ptr ProcessGroupCustom::Recv( - std::vector& tensors, int src_rank) { - CheckTensorsInDifferentCustomDevices(tensors, static_cast(GetSize())); - return Collective( - tensors, - tensors, - [&](phi::DenseTensor& input, - 
phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - phi::DeviceManager::CCLRecv(device_type_, - output.data(), - output.numel(), - phi::ccl::ToCCLDataType(output.dtype()), - src_rank, - comm, - stream); - }, - CommType::RECV, - false, - false); -} - std::shared_ptr ProcessGroupCustom::Send( const phi::DenseTensor& tensor, int dst_rank, @@ -659,192 +448,459 @@ std::shared_ptr ProcessGroupCustom::Send( // numel > 0 indicates the tensor need to be sliced const phi::DenseTensor& tensor_maybe_partial = numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor; - phi::distributed::CommStaticCheck::CheckShape( - tensor_maybe_partial, rank_, size_, phi::AllocationType::CUSTOM); - std::vector in_wrapper{tensor_maybe_partial}; - std::vector out_wrapper{tensor_maybe_partial}; - return Collective( - in_wrapper, - out_wrapper, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - phi::DeviceManager::CCLSend(device_type_, - input.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - dst_rank, - comm, - stream); + + return RunFnInXCCLEnv( + [&](const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + comm_context->Send(tensor_maybe_partial, + tensor_maybe_partial.numel(), + dst_rank, + stream); }, + tensor_maybe_partial, CommType::SEND, sync_op, use_calc_stream); } -std::shared_ptr ProcessGroupCustom::Send( - std::vector& tensors, int dst_rank) { - CheckTensorsInDifferentCustomDevices(tensors, static_cast(GetSize())); - return Collective( - tensors, - tensors, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - phi::DeviceManager::CCLSend(device_type_, - input.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - dst_rank, - comm, - stream); - }, - CommType::SEND, - false, - false); +std::shared_ptr ProcessGroupCustom::CreateTask( + const Place& place, + int rank, + CommType comm_type, + bool is_sync, + bool use_calc_stream) { + return std::make_shared( + place, rank, comm_type, is_sync, use_calc_stream); } -std::shared_ptr ProcessGroupCustom::Reduce( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceOptions& opts, +void ProcessGroupCustom::BroadcastUniqueXCCLID( + phi::ccl::CCLRootId* xccl_root_id) { + const std::string key = + "ProcessGroupCustom/xccl_ids/" + std::to_string(gid_) + "/0"; + if (rank_ == 0) { + store_->set(key, *xccl_root_id); + } else { + *xccl_root_id = store_->get(key); + } +} + +void ProcessGroupCustom::CreateXCCLEnvCache(const Place& place, + const std::string& place_key) { + if (!place_to_comm_ctx_.empty()) { + VLOG(3) << "Warning: Tensors from multiple devices are not supported yet."; + } + + VLOG(3) << "init xccl rank: " << rank_ << ", nranks: " << size_ + << ", place: " << place_key; + + phi::distributed::CommContextManager::CreateXCCLCommContext( + store_, std::to_string(gid_), place.GetDeviceType(), rank_, size_); + + auto* calc_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + auto comm_ctx = std::make_unique(place); + comm_ctx->SetAllocator( + &(phi::DeviceContextPool::Instance().Get(place)->GetAllocator())); + comm_ctx->SetHostAllocator( + &(phi::DeviceContextPool::Instance().Get(place)->GetHostAllocator())); + comm_ctx->SetZeroAllocator( + &(phi::DeviceContextPool::Instance().Get(place)->GetZeroAllocator())); + comm_ctx->SetHostZeroAllocator( + 
&(phi::DeviceContextPool::Instance().Get(place)->GetHostZeroAllocator())); + + auto xccl_comm_ctx = this->GetCommContext(); + comm_ctx->set_xccl_comm(xccl_comm_ctx->GetXcclComm()); + + auto xccl_event = std::make_unique(); + xccl_event->Init(place); + place_to_calc_event_.emplace(place_key, std::move(xccl_event)); + place_to_calc_ctx_.emplace(place_key, calc_ctx); + place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx)); + + // TODO(sunyilun): for compatibility, will be removed later + std::vector comm_ctx_wrapper{ + place_to_comm_ctx_[place_key].get()}; + places_to_ctx_.emplace(place_key, comm_ctx_wrapper); +} + +void ProcessGroupCustom::SyncCalcStream(const Place& place) { + const std::string& key = GetKeyFromPlace(place); + auto& calc_event = place_to_calc_event_.at(key); + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto* comm_ctx = place_to_comm_ctx_.at(key).get(); + calc_event->Record(calc_ctx->GetStream().get()); + comm_ctx->GetStream()->WaitEvent(calc_event.get()); +} + +std::shared_ptr ProcessGroupCustom::RunFnInXCCLEnv( + std::function fn, + const phi::DenseTensor& tensor, + CommType comm_type, bool sync_op, bool use_calc_stream) { - phi::distributed::CommStaticCheck::SameShape(*out_tensor, - in_tensor, - /*dst_rank*/ opts.root_rank, - /*cur_rank*/ rank_, - size_, - phi::AllocationType::CUSTOM); - std::vector in_wrapper{in_tensor}; - std::vector out_wrapper{*out_tensor}; + const auto& place = tensor.place(); + const auto& key = GetKeyFromPlace(place); + + phi::DeviceGuard guard(place); + + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateXCCLEnvCache(place, key); + } + + if (!use_calc_stream) { + SyncCalcStream(place); + } + + auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); + + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto& comm_ctx = place_to_comm_ctx_.at(key); + auto& xccl_stream = + use_calc_stream ? 
*calc_ctx->GetStream() : *comm_ctx->GetStream(); + fn(xccl_stream); + + if (!use_calc_stream) { + if (FLAGS_use_stream_safe_cuda_allocator) { + memory::RecordStream(tensor.Holder(), xccl_stream.raw_stream()); + } + task->UpdateWaitChain(*comm_ctx); + } + + return task; +} + +// TODO(sunyilun): methods below will be removed later +void SyncDefaultStream(const std::vector& places, + phi::event::Event& xccl_event, // NOLINT + std::vector& dev_ctx) { // NOLINT + for (size_t i = 0; i < places.size(); ++i) { + auto* default_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(places[i])); + xccl_event.Record(default_ctx->GetStream().get()); + dev_ctx[i]->GetStream()->WaitEvent(&xccl_event); + } +} + +std::shared_ptr ProcessGroupCustom::CreateTask( + std::vector places, + int rank, + CommType comm_type, + const std::vector& inputs) { + return std::make_shared( + places, rank, comm_type, inputs); +} + +ProcessGroupCustom::XCCLTask::XCCLTask( + const std::vector& places, + int rank, + CommType CommType, + const std::vector& inputs) + : TaskStream(rank, inputs, CommType), task_place_(places[0]) { + comm_event_.Init(places[0]); +} + +// create XCCLManager cache for places_key +void ProcessGroupCustom::CreateXCCLManagerCache( + const std::string& places_key, const std::vector& places) { + PADDLE_ENFORCE_EQ(places_key.empty(), + false, + phi::errors::PreconditionNotMet( + "Not able to create/get the XCCL Communicator since " + "the CustomPlace are not known")); + + phi::ccl::CCLRootId xccl_root_id; + if (rank_ == 0) { + phi::DeviceManager::CCLGetUniqueId(device_type_, &xccl_root_id); + } + BroadcastUniqueXCCLID(&xccl_root_id); + + VLOG(3) << "init xccl rank: " << rank_ << ", nranks: " << size_ + << ", place: " << places_key << ", xccl uniqueid: " + << phi::ccl::SerializeXCCLUniqueId(xccl_root_id); + + std::vector> dev_ctx; + dev_ctx.resize(places.size()); + + std::vector dev_ctx_raw; + dev_ctx_raw.resize(places.size()); + + GroupStart(device_type_); + + for (size_t i = 0; i < places.size(); ++i) { + phi::DeviceGuard guard(places[i]); + + dev_ctx[i] = std::make_unique(places[i]); + dev_ctx[i]->SetAllocator( + &(phi::DeviceContextPool::Instance().Get(places[i])->GetAllocator())); + dev_ctx[i]->SetHostAllocator(&( + phi::DeviceContextPool::Instance().Get(places[i])->GetHostAllocator())); + dev_ctx[i]->SetZeroAllocator(&( + phi::DeviceContextPool::Instance().Get(places[i])->GetZeroAllocator())); + dev_ctx[i]->SetHostZeroAllocator(&(phi::DeviceContextPool::Instance() + .Get(places[i]) + ->GetHostZeroAllocator())); + + phi::ccl::CCLComm xccl_comm; + phi::DeviceManager::CCLCommInitRank( + device_type_, GetSize(), &xccl_root_id, GetRank(), &xccl_comm); + + dev_ctx[i]->set_xccl_comm(xccl_comm); + dev_ctx_raw[i] = dev_ctx[i].get(); + } + + GroupEnd(device_type_); + + // TODO(sunyilun): for compatibility, will be removed later + auto xccl_event = std::make_unique(); + xccl_event->Init(places[0]); + place_to_calc_event_.emplace(places_key, std::move(xccl_event)); + place_to_calc_ctx_.emplace( + places_key, + static_cast( + platform::DeviceContextPool::Instance().Get(places[0]))); + place_to_comm_ctx_.emplace(places_key, std::move(dev_ctx[0])); + + // These caches will be useful to process sync/wait/communicate + places_to_ctx_.emplace(places_key, std::move(dev_ctx_raw)); +} + +template +std::shared_ptr ProcessGroupCustom::Collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + CommType op_type) { + const auto places = GetPlaceList(inputs); + const auto key = GetKeyFromPlaces(places); + + 
{ + std::lock_guard lock(mutex_); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateXCCLManagerCache(key, places); + } + } + + SyncDefaultStream( + places, *place_to_calc_event_.at(key), places_to_ctx_.at(key)); + + auto task = CreateTask(places, rank_, op_type, inputs); + + // construct uninitialize guard for device + { + GroupStart(device_type_); + for (size_t i = 0; i < inputs.size(); ++i) { + phi::DeviceGuard guard(places[i]); + const auto& xccl_stream = *places_to_ctx_.at(key)[i]->GetStream(); + fn(inputs[i], + outputs[i], + places_to_ctx_.at(key)[i]->xccl_comm(), + xccl_stream); + } + GroupEnd(device_type_); + } + + if (FLAGS_use_stream_safe_cuda_allocator) { + for (size_t i = 0; i < inputs.size(); ++i) { + phi::DeviceGuard guard(places[i]); + memory::RecordStream(inputs[i].Holder(), + places_to_ctx_.at(key)[i]->stream()); + } + } + + for (size_t i = 0; i < inputs.size(); ++i) { + phi::DeviceGuard guard(places[i]); + task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); + } + return task; +} + +template +std::shared_ptr ProcessGroupCustom::PointToPoint( + std::vector& tensors, + Fn fn, + int dst_rank, + CommType op_type) { + const auto places = GetPlaceList(tensors); + const auto key = GetKeyFromPlaces(places); + + { + std::lock_guard lock(mutex_); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateXCCLManagerCache(key, places); + } + } + + SyncDefaultStream( + places, *place_to_calc_event_.at(key), places_to_ctx_.at(key)); + + auto task = CreateTask(places, rank_, op_type, tensors); + + // construct uninitialize guard for device + + { + GroupStart(device_type_); + for (size_t i = 0; i < tensors.size(); ++i) { + phi::DeviceGuard guard(places[i]); + + const auto& xccl_stream = *places_to_ctx_.at(key)[i]->GetStream(); + fn(tensors[i], + places_to_ctx_.at(key)[i]->xccl_comm(), + xccl_stream, + dst_rank); + } + GroupEnd(device_type_); + } + + if (FLAGS_use_stream_safe_cuda_allocator) { + for (size_t i = 0; i < tensors.size(); ++i) { + phi::DeviceGuard guard(places[i]); + memory::RecordStream(tensors[i].Holder(), + places_to_ctx_.at(key)[i]->stream()); + } + } + + for (size_t i = 0; i < tensors.size(); ++i) { + phi::DeviceGuard guard(places[i]); + task->UpdateWaitChain(*places_to_ctx_.at(key)[i]); + } + return task; +} + +std::shared_ptr ProcessGroupCustom::AllReduce( + std::vector& in_tensors, + std::vector& out_tensors, + const AllreduceOptions& opts) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCustomPlace(in_tensors, device_type_), + true, + phi::errors::InvalidArgument("All inputs should be in CustomPlace.")); return Collective( - in_wrapper, - out_wrapper, - [&](phi::DenseTensor& input, + in_tensors, + out_tensors, + [&](const phi::DenseTensor& input, phi::DenseTensor& output, - phi::ccl::CCLComm comm, + const phi::ccl::CCLComm& comm, const phi::stream::Stream& stream) { - phi::DeviceManager::CCLReduce(device_type_, - input.data(), - output.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - ToCustomCCLRedType(opts.reduce_op), - opts.root_rank, - comm, - stream); + auto comm_context = this->GetCommContext(); + comm_context->AllReduce( + &output, + input, + paddle::distributed::ToXCCLRedType(opts.reduce_op), + stream); }, - CommType::REDUCE, - sync_op, - use_calc_stream); + CommType::ALLREDUCE); } -std::shared_ptr ProcessGroupCustom::Reduce( +std::shared_ptr ProcessGroupCustom::Broadcast( std::vector& in_tensors, std::vector& out_tensors, - const ReduceOptions& opts) { + const BroadcastOptions& opts) { PADDLE_ENFORCE_EQ( 
CheckTensorsInCustomPlace(in_tensors, device_type_), true, phi::errors::InvalidArgument("All inputs should be in CustomPlace.")); + return Collective( in_tensors, out_tensors, [&](phi::DenseTensor& input, phi::DenseTensor& output, - phi::ccl::CCLComm comm, + const phi::ccl::CCLComm& comm, const phi::stream::Stream& stream) { - phi::DeviceManager::CCLReduce(device_type_, - input.data(), - output.data(), - input.numel(), - phi::ccl::ToCCLDataType(input.dtype()), - ToCustomCCLRedType(opts.reduce_op), - opts.root_rank, - comm, - stream); + const auto root = + opts.source_rank * in_tensors.size() + opts.source_root; + auto comm_context = this->GetCommContext(); + comm_context->Broadcast(&output, input, root, stream); }, - CommType::REDUCE, + CommType::BROADCAST); +} + +void CheckTensorsInDifferentDevices( + const std::vector& tensors, const size_t num_devices) { + PADDLE_ENFORCE_EQ( + tensors.empty(), false, - false); + phi::errors::InvalidArgument("Tensor list must be nonempty.")); + PADDLE_ENFORCE_LE( + tensors.size(), + num_devices, + phi::errors::InvalidArgument("Tensor list mustn't be larger than the " + "number of available CustomDevices.")); + + std::set used_devices; + + for (const auto& t : tensors) { + PADDLE_ENFORCE_EQ(platform::is_custom_place(t.place()), + true, + phi::errors::InvalidArgument( + "Tensors must be CustomDevice and dense tensor.")); + + const auto inserted = used_devices.insert(t.place()).second; + PADDLE_ENFORCE_EQ(inserted, + true, + phi::errors::InvalidArgument( + "Tensors must be on distinct CustomDevice devices.")); + } } -std::shared_ptr ProcessGroupCustom::AllToAll( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const std::vector& out_size_each_rank, - const std::vector& in_size_each_rank, - bool sync_op, - bool use_calc_stream) { - const phi::DDim& out_dim = out_tensor->dims(); - const phi::DDim& in_dim = in_tensor.dims(); - CheckSizeOnEachRank(out_dim, out_size_each_rank, size_); - CheckSizeOnEachRank(in_dim, in_size_each_rank, size_); +std::shared_ptr ProcessGroupCustom::Send( + std::vector& tensors, int dst_rank) { + CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); - // NOTE: Since `all_to_all` needs other processes' participation, it cannot - // simply be covered by static checks. Factors are set to 0 here to skip the - // shape check. Its shape check will be done by dynamic checks with - // FLAGS_enable_nccl_dynamic_check. 
- phi::distributed::CommStaticCheck::CheckShape(*out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_, - /*out_size_factor*/ 0, - /*in_size_factor*/ 0, - phi::AllocationType::CUSTOM); - - std::vector in_wrapper{in_tensor}; - std::vector out_wrapper{*out_tensor}; - return Collective( - in_wrapper, - out_wrapper, + auto task = PointToPoint( + tensors, [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - int64_t in_row_size = in_tensor.numel() / in_dim[0], - out_row_size = out_tensor->numel() / out_dim[0]; - int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; - phi::DenseTensor input_partial, output_partial; - std::vector send_buf, recv_buf; - std::vector send_count, recv_count; - std::vector send_dtype, recv_dtype; + const phi::ccl::CCLComm& comm, + const phi::stream::Stream& stream, + int dst_rank) { + auto comm_context = this->GetCommContext(); + comm_context->Send(input, input.numel(), dst_rank, stream); + }, + dst_rank, + CommType::SEND); + return task; +} - for (auto i = 0; i < size_; i++) { - in_numel = in_size_each_rank[i] * in_row_size; - input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); - out_numel = out_size_each_rank[i] * out_row_size; - output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel); - in_offset += in_numel; - out_offset += out_numel; - send_buf.push_back(input_partial.data()); - recv_buf.push_back(output_partial.data()); - send_count.push_back(in_numel); - recv_count.push_back(out_numel); - send_dtype.push_back(phi::ccl::ToCCLDataType(input_partial.dtype())); - recv_dtype.push_back(phi::ccl::ToCCLDataType(output_partial.dtype())); - } +std::shared_ptr ProcessGroupCustom::Recv( + std::vector& tensors, int src_rank) { + CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); - phi::DeviceManager::CCLAllToAll( - device_type_, - const_cast(send_buf.data()), - send_count.data(), - send_dtype.data(), - recv_buf.data(), - recv_count.data(), - recv_dtype.data(), - rank_, - size_, - comm, - stream); + auto task = PointToPoint( + tensors, + [&](phi::DenseTensor& output, + const phi::ccl::CCLComm& comm, + const phi::stream::Stream& stream, + int src_rank) { + auto comm_context = this->GetCommContext(); + comm_context->Recv(&output, output.numel(), src_rank, stream); }, - CommType::ALLTOALL, - sync_op, - use_calc_stream); + src_rank, + CommType::RECV); + return task; +} + +std::shared_ptr ProcessGroupCustom::AllGather( + std::vector& in_tensors, + std::vector& out_tensors) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCustomPlace(in_tensors, device_type_), + true, + phi::errors::InvalidArgument("All inputs should be in CustomPlace.")); + PADDLE_ENFORCE_EQ( + CheckTensorsInCustomPlace(out_tensors, device_type_), + true, + phi::errors::InvalidArgument("All outputs should be in CustomPlace.")); + return Collective( + in_tensors, + out_tensors, + [&](const phi::DenseTensor& input, + phi::DenseTensor& output, + const phi::ccl::CCLComm& comm, + const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + comm_context->AllGather(&output, input, stream); + }, + CommType::ALLGATHER); } std::shared_ptr ProcessGroupCustom::AllToAll( @@ -863,8 +919,10 @@ std::shared_ptr ProcessGroupCustom::AllToAll( out_tensors, [&](phi::DenseTensor& input, phi::DenseTensor& output, - phi::ccl::CCLComm comm, + const phi::ccl::CCLComm& comm, const phi::stream::Stream& stream) { + auto comm_context = this->GetCommContext(); + size_t offset = 0; 
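+        // Equal-split case (old multi-tensor API): every rank exchanges
+        // input.numel() / size_ elements with each peer, so all send/recv
+        // counts built below are identical.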
std::vector send_buf, recv_buf; std::vector send_count(size_, input.numel() / size_), @@ -889,111 +947,35 @@ std::shared_ptr ProcessGroupCustom::AllToAll( recv_dtype.data(), rank_, size_, - comm, - stream); - }, - CommType::ALLTOALL, - false, - false); -} - -std::shared_ptr ProcessGroupCustom::ReduceScatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ReduceScatterOptions& opts, - bool sync_op, - bool use_calc_stream) { - phi::distributed::CommStaticCheck::ScatterLikeShape( - *out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_, - phi::AllocationType::CUSTOM); - std::vector in_wrapper{in_tensor}; - std::vector out_wrapper{*out_tensor}; - return Collective( - in_wrapper, - out_wrapper, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - phi::DeviceManager::CCLReduceScatter( - device_type_, - const_cast(in_tensor.data()), - out_tensor->data(), - out_tensor->numel(), - phi::ccl::ToCCLDataType(in_tensor.dtype()), - paddle::distributed::ToCustomCCLRedType(opts.reduce_op), - comm, + comm_context->GetXcclComm(), stream); }, - CommType::REDUCE_SCATTER, - false, - false); + CommType::ALLTOALL); } -std::shared_ptr ProcessGroupCustom::Scatter( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const ScatterOptions& opts, - bool sync_op, - bool use_calc_stream) { - phi::distributed::CommStaticCheck::ScatterLikeShape( - *out_tensor, - in_tensor, - /*dst_rank*/ opts.root_rank, - /*cur_rank*/ rank_, - size_, - phi::AllocationType::CUSTOM); - std::vector in_wrapper{in_tensor}; - std::vector out_wrapper{*out_tensor}; +std::shared_ptr ProcessGroupCustom::Reduce( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceOptions& opts) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCustomPlace(in_tensors, device_type_), + true, + phi::errors::InvalidArgument("All inputs should be in CustomPlace.")); return Collective( - in_wrapper, - out_wrapper, - [&](phi::DenseTensor& input, + in_tensors, + out_tensors, + [&](const phi::DenseTensor& input, phi::DenseTensor& output, - phi::ccl::CCLComm comm, + const phi::ccl::CCLComm& comm, const phi::stream::Stream& stream) { - int64_t numel = in_tensor.numel() / size_; - if (rank_ == opts.root_rank) { - int64_t offset = 0; - phi::DenseTensor partial_tensor; - for (auto i = 0; i < size_; i++) { - partial_tensor = GetPartialTensor(in_tensor, offset, numel); - if (i != rank_) { - phi::DeviceManager::CCLSend( - device_type_, - partial_tensor.data(), - numel, - phi::ccl::ToCCLDataType(partial_tensor.dtype()), - i, - comm, - stream); - } else { - phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace()) - ->MemoryCopyD2D(out_tensor->data(), - partial_tensor.data(), - numel * phi::SizeOf(partial_tensor.dtype()), - &stream); - } - offset += numel; - } - } else { - phi::DeviceManager::CCLRecv( - device_type_, - out_tensor->data(), - numel, - phi::ccl::ToCCLDataType(out_tensor->dtype()), - opts.root_rank, - comm, - stream); - } + auto comm_context = this->GetCommContext(); + comm_context->Reduce(&output, + input, + paddle::distributed::ToXCCLRedType(opts.reduce_op), + opts.root_rank, + stream); }, - CommType::SCATTER, - sync_op, - use_calc_stream); + CommType::REDUCE); } std::shared_ptr ProcessGroupCustom::Scatter( @@ -1003,134 +985,36 @@ std::shared_ptr ProcessGroupCustom::Scatter( PADDLE_ENFORCE_EQ( CheckTensorsInCustomPlace(in_tensors, device_type_), true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); + 
phi::errors::InvalidArgument("All inputs should be in CustomPlace.")); PADDLE_ENFORCE_EQ( CheckTensorsInCustomPlace(out_tensors, device_type_), true, - phi::errors::InvalidArgument("All inputs should be in CudaPlace.")); + phi::errors::InvalidArgument("All inputs should be in CustomPlace.")); return Collective( in_tensors, out_tensors, [&](phi::DenseTensor& input, phi::DenseTensor& output, - phi::ccl::CCLComm comm, - const phi::stream::Stream& stream) { - int64_t numel = input.numel() / size_; - if (rank_ == opts.root_rank) { - int64_t offset = 0; - phi::DenseTensor partial_tensor; - for (auto i = 0; i < size_; i++) { - partial_tensor = GetPartialTensor(input, offset, numel); - if (i != rank_) { - phi::DeviceManager::CCLSend( - device_type_, - partial_tensor.data(), - numel, - phi::ccl::ToCCLDataType(partial_tensor.dtype()), - i, - comm, - stream); - } else { - phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace()) - ->MemoryCopyD2D(output.data(), - partial_tensor.data(), - numel * phi::SizeOf(partial_tensor.dtype()), - &stream); - } - offset += numel; - } - } else { - phi::DeviceManager::CCLRecv(device_type_, - output.data(), - numel, - phi::ccl::ToCCLDataType(output.dtype()), - opts.root_rank, - comm, - stream); - } - }, - CommType::SCATTER, - false, - false); -} - -std::shared_ptr ProcessGroupCustom::Gather( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const GatherOptions& opts, - bool sync_op, - bool use_calc_stream) { - std::vector partial_tensors; - if (rank_ == opts.root_rank) { - partial_tensors.reserve(size_); - size_t offset = 0; - size_t numel = out_tensor->numel() / size_; - for (auto i = 0; i < size_; i++) { - partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel)); - offset += numel; - } - } - return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream); -} - -std::shared_ptr ProcessGroupCustom::Gather( - std::vector* gather_tensors_ptr, - const phi::DenseTensor& in_tensor, - const GatherOptions& opts, - bool sync_op, - bool use_calc_stream) { - auto& gather_tensors = *gather_tensors_ptr; - PADDLE_ENFORCE_GT(size_, - opts.root_rank, - phi::errors::InvalidArgument( - "root world size [%d] is less than root rank [%d]", - size_, - opts.root_rank)); - std::vector in_wrapper{in_tensor}; - return Collective( - in_wrapper, - in_wrapper, - [&](phi::DenseTensor& input, - phi::DenseTensor& output, - phi::ccl::CCLComm comm, + const phi::ccl::CCLComm& comm, const phi::stream::Stream& stream) { - // root receive from all devices + auto comm_context = this->GetCommContext(); + size_t offset = 0; + size_t count = input.numel() / size_; if (rank_ == opts.root_rank) { + comm_context->GroupStart(); for (auto i = 0; i < size_; i++) { - auto& gather_tensor = gather_tensors[i]; - if (i != rank_) { - phi::DeviceManager::CCLRecv( - device_type_, - gather_tensor.data(), - gather_tensor.numel(), - phi::ccl::ToCCLDataType(gather_tensor.dtype()), - i, - comm, - stream); - } else { - phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace()) - ->MemoryCopyD2D( - gather_tensor.data(), - in_tensor.data(), - in_tensor.numel() * phi::SizeOf(in_tensor.dtype()), - &stream); - } + auto input_data = reinterpret_cast( + GetPointerByOffset(input.data(), offset, input.dtype())); + comm_context->Send(*input_data, count, i, stream); + offset += count; } + comm_context->Recv(&output, count, opts.root_rank, stream); + comm_context->GroupEnd(); } else { - // send to root - phi::DeviceManager::CCLSend( - device_type_, - const_cast(in_tensor.data()), - 
in_tensor.numel(), - phi::ccl::ToCCLDataType(in_tensor.dtype()), - opts.root_rank, - comm, - stream); + comm_context->Recv(&output, count, opts.root_rank, stream); } }, - CommType::GATHER, - sync_op, - use_calc_stream); + CommType::SCATTER); } std::shared_ptr @@ -1146,5 +1030,16 @@ ProcessGroupCustom::CreateProcessGroupCustom( return process_group; } +phi::distributed::XCCLCommContext* ProcessGroupCustom::GetCommContext() { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + auto comm_context = static_cast( + comm_context_manager.Get(std::to_string(this->gid_))); + PADDLE_ENFORCE_NE(comm_context, + nullptr, + phi::errors::Unavailable("XCCLCommContext is nullptr")); + return comm_context; +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index 9e891df96f1e9adb85de8f4d00a2cc4ba94bb915..c60d185c9e480827b1d9b51f3510b922a3faeddc 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -15,62 +15,57 @@ #pragma once #include -#include #include #include #include #include -#include "paddle/fluid/distributed/collective/custom_ccl_tools.h" #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/gen_comm_id_helper.h" -#include "paddle/fluid/platform/place.h" +#include "paddle/phi/backends/device_manager.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/distributed/store/store.h" +#include "paddle/phi/core/distributed/xccl_comm_context.h" namespace paddle { namespace distributed { -using Place = paddle::platform::Place; -using CustomDeviceContext = paddle::platform::CustomDeviceContext; -class ProcessGroupCustom : public ProcessGroupWithStream { +using Place = phi::Place; + +class ProcessGroupCustom final : public ProcessGroupWithStream { public: - class CustomTask : public ProcessGroup::Task, - public std::enable_shared_from_this { + class XCCLTask final : public ProcessGroupWithStream::TaskStream, + public std::enable_shared_from_this { public: - CustomTask(const std::vector& places, - int rank, - CommType CommType, - const std::vector& inputs); + XCCLTask(const Place& place, + int rank, + CommType comm_type, + bool sync_op, + bool use_calc_stream); + virtual ~XCCLTask(); bool IsCompleted() override; - void SynchronizeStreams(); bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override; void Synchronize() override; void UpdateWaitChain(const phi::DeviceContext& ctx) override; - void SetOutputs(std::vector& outputs); // NOLINT - virtual ~CustomTask(); - std::vector control_events_; - std::vector barrierTensors_; + bool IsBlockCPUInWait() const { return block_cpu_in_wait_; } + void SetBlockCPUInWait() { block_cpu_in_wait_ = true; } - protected: - std::vector places_; - std::vector> cclComms_; - std::shared_ptr> outputs_; + // TODO(sunyilun): methods below will be removed later + XCCLTask(const std::vector& places, + int rank, + CommType CommType, + const std::vector& inputs); private: - const std::string device_type_; + bool block_cpu_in_wait_{false}; + phi::event::Event comm_event_; // event on comm stream + Place task_place_; }; - ProcessGroupCustom(const 
std::shared_ptr& store, - const std::string& device_type, - int rank, - int size, - int gid); - + public: static std::shared_ptr CreateProcessGroupCustom( const std::shared_ptr& store, const std::string& device_type, @@ -78,19 +73,18 @@ class ProcessGroupCustom : public ProcessGroupWithStream { int size, int gid); - std::string GetBackendName() const override { return "XCCL_" + device_type_; } + ProcessGroupCustom(const std::shared_ptr& store, + const std::string& device_type, + int rank, + int size, + int gid); - std::shared_ptr Barrier( - const BarrierOptions& = BarrierOptions()) override; + std::string GetBackendName() const override { return "XCCL"; } phi::DeviceContext* GetDeviceContext(const Place& place) const override; - phi::ccl::CCLComm CustomCCLComm(const Place& place) const; - - // TODO(sunyilun): methods below will be removed later - std::shared_ptr AllGather( - std::vector& in_tensors, - std::vector& out_tensors) override; + phi::DeviceContext* GetDeviceContext(const Place& place, + bool use_calc_stream) const override; std::shared_ptr AllGather( phi::DenseTensor* out_tensor, @@ -100,11 +94,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream { bool sync_op, bool use_calc_stream) override; - std::shared_ptr AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& = AllreduceOptions()) override; - std::shared_ptr AllReduce( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, @@ -112,10 +101,16 @@ class ProcessGroupCustom : public ProcessGroupWithStream { bool sync_op, bool use_calc_stream) override; - std::shared_ptr Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& = BroadcastOptions()) override; + std::shared_ptr AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Barrier( + const BarrierOptions& = BarrierOptions()) override; std::shared_ptr Broadcast( phi::DenseTensor* out_tensor, @@ -124,49 +119,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream { bool sync_op, bool use_calc_stream) override; - std::shared_ptr Send(const phi::DenseTensor& tensor, - int dst_rank, - int64_t offset, - int64_t numel, - bool sync_op, - bool use_calc_stream) override; - - std::shared_ptr Send( - std::vector& tensors, int dst_rank) override; - - std::shared_ptr Recv(phi::DenseTensor* tensor, - int src_rank, - int64_t offset, - int64_t numel, - bool sync_op, - bool use_calc_stream) override; - - std::shared_ptr Recv( - std::vector& tensors, int src_rank) override; - std::shared_ptr Reduce(phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, const ReduceOptions& opts, bool sync_op, bool use_calc_stream) override; - std::shared_ptr Reduce( - std::vector& tensors, - std::vector& out_tensors, - const ReduceOptions& opts) override; - - std::shared_ptr AllToAll( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - const std::vector& out_size_each_rank, - const std::vector& in_size_each_rank, - bool sync_op, - bool use_calc_stream) override; - - std::shared_ptr AllToAll( - std::vector& in_tensors, - std::vector& out_tensors) override; - std::shared_ptr ReduceScatter( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, @@ -180,11 +138,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream { bool sync_op, bool use_calc_stream) override; - std::shared_ptr Scatter( - 
std::vector& in_tensors, - std::vector& out_tensors, - const ScatterOptions& opts) override; - std::shared_ptr Gather(phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, const GatherOptions& opts, @@ -198,52 +151,124 @@ class ProcessGroupCustom : public ProcessGroupWithStream { bool sync_op, bool use_calc_stream) override; - protected: - virtual std::shared_ptr CreateTask( - std::vector places, - int rank, - CommType opType, - const std::vector& inputs); + std::shared_ptr Recv(phi::DenseTensor* tensor, + int src_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) override; - std::shared_ptr store_; - std::shared_ptr custom_comm_; - std::mutex mutex_; - std::unordered_map>> - places_to_customcomm_; - std::unordered_map> - places_to_events_; - std::unordered_map>> - places_to_ctx_; - std::set used_place_ids_; + std::shared_ptr Send(const phi::DenseTensor& tensor, + int dst_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) override; + + static void GroupStart(const std::string& dev_type); + + static void GroupEnd(const std::string& dev_type); + + phi::ccl::CCLComm XCCLComm(const Place& place) const; + + // TODO(liyurui): This API will be moved later + std::shared_ptr AllReduce( + std::vector& in_tensors, + std::vector& out_tensors, + const AllreduceOptions& = AllreduceOptions()) override; + + // TODO(sunyilun): methods below will be removed later + std::shared_ptr Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& = BroadcastOptions()) override; + + std::shared_ptr Send( + std::vector& tensors, int dst_rank) override; + + std::shared_ptr Recv( + std::vector& tensors, int src_rank) override; + + std::shared_ptr AllGather( + std::vector& in_tensors, + std::vector& out_tensors) override; + + std::shared_ptr AllToAll( + std::vector& in_tensors, + std::vector& out_tensors) override; + + std::shared_ptr Reduce( + std::vector& tensors, + std::vector& out_tensors, + const ReduceOptions& opts) override; + + std::shared_ptr Scatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ScatterOptions& opts) override; private: - void BcastCustomId(std::vector& ccl_ids, // NOLINT - int root, - int server_fd); + std::shared_ptr CreateTask( + const Place& place, + int rank, + CommType op_type, + bool sync_op, + bool use_calc_stream); - void BroadcastUniqueCustomID( - std::vector& custom_ccl_ids); // NOLINT + void BroadcastUniqueXCCLID(phi::ccl::CCLRootId* nccl_id); + + void CreateXCCLEnvCache(const Place& place, const std::string& place_key); + + void SyncCalcStream(const Place& place); + + std::shared_ptr RunFnInXCCLEnv( + std::function fn, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream); + + // TODO(sunyilun): methods below will be removed later + std::shared_ptr CreateTask( + std::vector places, + int rank, + CommType op_type, + const std::vector& inputs); template std::shared_ptr Collective( std::vector& inputs, // NOLINT std::vector& outputs, // NOLINT Fn fn, - CommType op_type, - bool sync_op, - bool use_calc_stream); + CommType op_type); template - std::shared_ptr Collective(Fn fn, - CommType op_type, - bool sync_op, - bool use_calc_stream); - - void CreateCustomManagerCache(const std::string& places_key, - const std::vector& places); - const std::string device_type_; + std::shared_ptr PointToPoint( + std::vector& tensors, // NOLINT + Fn fn, + int dst_rank, + CommType op_type); + + void CreateXCCLManagerCache(const std::string& 
places_key, + const std::vector& places); + + phi::distributed::XCCLCommContext* GetCommContext(); + + private: + std::shared_ptr store_; + std::string device_type_; + + std::unordered_map> + place_to_calc_event_; // event on calc stream + std::unordered_map place_to_calc_ctx_; + std::unordered_map> + place_to_comm_ctx_; + + // TODO(sunyilun): attrs below will be removed later + std::mutex mutex_; + std::unordered_map> + places_to_ctx_; }; + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/processgroup_comm_utils.cc b/paddle/fluid/distributed/collective/processgroup_comm_utils.cc index 26171108459284415048593d6120c68f25bbf0f6..94723906fccb13a73ac9ee5f5090c7bfb08677de 100644 --- a/paddle/fluid/distributed/collective/processgroup_comm_utils.cc +++ b/paddle/fluid/distributed/collective/processgroup_comm_utils.cc @@ -53,8 +53,8 @@ ccl::CCLComm GetCCLComm(const Place& place, int global_gid) { #endif } else if (place.GetType() == phi::AllocationType::CUSTOM) { #if defined(PADDLE_WITH_CUSTOM_DEVICE) - return static_cast(pg) - ->CustomCCLComm(place); + return static_cast(pg)->XCCLComm( + place); #else return nullptr; #endif diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 0800f091837325164b4127a357b11368bd325bbd..7add694a04f68fe5e01fba7b4a451dd9ff42a600 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -1574,6 +1574,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place, m_->SetDefaultStream(place, stream); } } + #endif UNUSED static std::shared_ptr unused_obj = diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h index 20be4c11bfe2d53d7fdfd34639b27886a990dddc..0131d56c6f6428e705f0a3f48c2f88cd062d1001 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.h +++ b/paddle/fluid/memory/allocation/allocator_facade.h @@ -101,6 +101,7 @@ class AllocatorFacade { phi::stream::stream_t stream); void RecordStream(std::shared_ptr allocation, phi::stream::stream_t stream); + void SetDefaultStream(const platform::CustomPlace& place, phi::stream::stream_t stream); #endif diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc index 46f9b1189cb6853f75297d634794047b5533f972..c888a1d6041112e927765c3780365b542b47ae66 100644 --- a/paddle/fluid/memory/malloc.cc +++ b/paddle/fluid/memory/malloc.cc @@ -71,6 +71,14 @@ gpuStream_t GetStream(const std::shared_ptr& allocation) { return allocation::AllocatorFacade::Instance().GetStream(allocation); } +#endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +void RecordStream(std::shared_ptr allocation, + phi::stream::stream_t stream) { + return allocation::AllocatorFacade::Instance().RecordStream(allocation, + stream); +} #endif } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/malloc.h b/paddle/fluid/memory/malloc.h index b8f5f0289c4bcd38cf5c6a15e9a0596b0aae9de6..de9c9f7367367c1e2938d2d4e3fe0fb7753e0e75 100644 --- a/paddle/fluid/memory/malloc.h +++ b/paddle/fluid/memory/malloc.h @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/stream.h" @@ -55,5 +56,9 @@ void RecordStream(std::shared_ptr allocation, gpuStream_t stream); gpuStream_t GetStream(const std::shared_ptr& allocation); #endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE +void RecordStream(std::shared_ptr allocation, + phi::stream::stream_t stream); +#endif } // namespace memory } // namespace paddle diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index 76877cdfae741e159c2604670afea4137c26f6f4..bf58f1d6ac0d84006c4fc708645e1a2e801a0a14 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -56,6 +56,12 @@ void BindCommContextManager(py::module *m) { "create_gloo_comm_context", &phi::distributed::CommContextManager::CreateGlooCommContext, py::call_guard()) +#endif +#if defined(PADDLE_WITH_CUSTOM_DEVICE) + .def_static( + "create_xccl_comm_context", + &phi::distributed::CommContextManager::CreateXCCLCommContext, + py::call_guard()) #endif .def("set_store", &phi::distributed::CommContextManager::SetStore); } diff --git a/paddle/phi/backends/c_comm_lib.h b/paddle/phi/backends/c_comm_lib.h index 31b84e34d38368944a1aca31967b926763eb4edc..30ebe6d2fa49611e849f1efd7a4301999c2656ac 100644 --- a/paddle/phi/backends/c_comm_lib.h +++ b/paddle/phi/backends/c_comm_lib.h @@ -21,6 +21,8 @@ #include "paddle/phi/core/errors.h" #include "paddle/phi/core/macros.h" +#include "paddle/phi/common/reduce_type.h" + namespace phi { namespace ccl { typedef void* CCLComm; @@ -38,6 +40,32 @@ enum CCLDataType { CCL_DATA_TYPE_UINT8 }; +inline CCLReduceOp ToXCCLReduceOp(int reduce_type) { + phi::ccl::CCLReduceOp red_type = phi::ccl::CCLReduceOp::SUM; + switch (static_cast(reduce_type)) { + case phi::ReduceType::kRedSum: + red_type = phi::ccl::CCLReduceOp::SUM; + break; + case phi::ReduceType::kRedMax: + red_type = phi::ccl::CCLReduceOp::MAX; + break; + case phi::ReduceType::kRedMin: + red_type = phi::ccl::CCLReduceOp::MIN; + break; + case phi::ReduceType::kRedProd: + red_type = phi::ccl::CCLReduceOp::PRODUCT; + break; + case phi::ReduceType::kRedAvg: + red_type = phi::ccl::CCLReduceOp::AVG; + break; + default: + PADDLE_THROW( + errors::Unavailable("Unsuppored reduce type. 
Reduce type must be one " + "of SUM, MAX, MIN, PRODUCT and AVG.")); + } + return red_type; +} + inline CCLDataType ToCCLDataType(phi::DataType type) { if (type == phi::DataType::FLOAT64) { return CCL_DATA_TYPE_FP64; @@ -79,5 +107,14 @@ inline phi::DataType ToPhiDataType(CCLDataType type) { } } +inline std::string SerializeXCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) { + const uint8_t* bytes = ccl_id.data(); + std::ostringstream oss; + for (size_t i = 0; i < ccl_id.size(); ++i) { + oss << std::hex << static_cast(bytes[i]); + } + return oss.str(); +} + } // namespace ccl } // namespace phi diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc index 8d5b44998e9fe60dbfc265fffe9bef200bcc6df0..ef91c0c9f65d3742dccd64f959695abc8d17600c 100644 --- a/paddle/phi/backends/custom/custom_device.cc +++ b/paddle/phi/backends/custom/custom_device.cc @@ -754,13 +754,15 @@ class CustomDevice : public DeviceInterface { } void CCLGroupStart() override { - CHECK_PTR(pimpl_->xccl_group_start); - PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start()); + if (pimpl_->xccl_group_start) { + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_start()); + } } void CCLGroupEnd() override { - CHECK_PTR(pimpl_->xccl_group_end); - PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end()); + if (pimpl_->xccl_group_end) { + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_group_end()); + } } void CCLSend(void* send_buf, diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt index e759b7d9c8da9511bb66e7acc8a0e30a7276e194..12c59059c7c3229a1d2cd5d1a80f3adc0a4e208a 100644 --- a/paddle/phi/core/distributed/CMakeLists.txt +++ b/paddle/phi/core/distributed/CMakeLists.txt @@ -12,4 +12,8 @@ if(WITH_GLOO) list(APPEND DISTRIBUTED_COMMON_SRCS gloo_utils.cc gloo_comm_context.cc) endif() +if(WITH_CUSTOM_DEVICE) + list(APPEND DISTRIBUTED_COMMON_SRCS xccl_comm_context.cc) +endif() + collect_srcs(core_srcs SRCS ${DISTRIBUTED_COMMON_SRCS}) diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.cc b/paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.cc index 63737d2d1b03ea47aa5a13743b6bc5b7418deba4..c1e1421ab730291f5c6fd6d5fe946075b8d4a079 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.cc @@ -53,6 +53,18 @@ DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx, })); return out; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (phi::CustomContext::classof(dev_ctx)) { + PD_VISIT_FLOATING_AND_INTEGRAL_TYPES( + input.dtype(), "AllGather", ([&] { + AllGather(static_cast(*dev_ctx), + input, + world_size, + &out); + })); + return out; + } #endif PADDLE_THROW(phi::errors::Unimplemented( "The all_gather in reshard only supported on CPU and GPU for now.")); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc index 10d0272209ddfa996325c67039e74eda6573f405..60b60ab9421fcaf4407a00756f89175c73edf6a0 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc @@ -170,6 +170,15 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx, #else PADDLE_THROW(phi::errors::Unimplemented( "Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on.")); +#endif + } else if (phi::CustomContext::classof(&dev_ctx)) { +#ifdef 
PADDLE_WITH_CUSTOM_DEVICE + CommContextManager::CreateXCCLCommContext( + store, + unique_comm_key, + dev_ctx.GetPlace().GetDeviceType(), + rank, + world_size); #endif } else { #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/phi/core/distributed/comm_context.h b/paddle/phi/core/distributed/comm_context.h index 7fad9b6e24925c68b7b65291291e1f39f7729b44..173ff6f8673d486e6872ed0149f5ee78d14c9faa 100644 --- a/paddle/phi/core/distributed/comm_context.h +++ b/paddle/phi/core/distributed/comm_context.h @@ -24,8 +24,8 @@ class CommContext { CommContext(int rank, int size) : rank_(rank), size_(size) {} virtual ~CommContext() = default; - int GetRank() { return rank_; } - int GetSize() { return size_; } + int GetRank() const { return rank_; } + int GetSize() const { return size_; } protected: int rank_; diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 89b98837834a7ce34bef5450574ee426c001d636..3154b44c5ecf13e17e1a46e8073da4073cbd1e9d 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -32,6 +32,9 @@ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/phi/core/distributed/nccl_comm_context.h" #endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/phi/core/distributed/xccl_comm_context.h" +#endif namespace phi { namespace distributed { @@ -91,6 +94,35 @@ void CommContextManager::CreateGlooCommContext( } #endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE +void CommContextManager::CreateXCCLCommContext( + const std::shared_ptr& store, + const std::string& unique_comm_key, + const std::string& device_type, + int rank, + int size) { + phi::ccl::CCLRootId xccl_root_id; + if (rank == 0) { + phi::DeviceManager::CCLGetUniqueId(device_type, &xccl_root_id); + } + + std::string unique_key = "XCCLCommContext/" + unique_comm_key; + if (rank == 0) { + store->set(unique_key, xccl_root_id); + } else { + xccl_root_id = store->get(unique_key); + } + VLOG(3) << "init xccl rank: " << rank << ", nranks: " << size + << ", unique_comm_key: " << unique_comm_key << ", xccl uniqueid: " + << phi::ccl::SerializeXCCLUniqueId(xccl_root_id); + auto xccl_comm_context = + std::make_unique(device_type, rank, size, xccl_root_id); + auto& comm_context_manager = CommContextManager::GetInstance(); + comm_context_manager.SetStore(store); + comm_context_manager.Emplace(unique_comm_key, std::move(xccl_comm_context)); +} +#endif + CommContext* CommContextManager::Emplace( const std::string& unique_comm_key, std::unique_ptr comm_context) { diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index 31fc0d43b2535b1eb037e1e67dfeb6e63cf469dd..6d82e89f92ba0282d58b96f39b179528474fbf03 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -62,6 +62,14 @@ class CommContextManager { int size); #endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + static void CreateXCCLCommContext(const std::shared_ptr& store, + const std::string& unique_comm_key, + const std::string& device_type, + int rank, + int size); +#endif + private: DISABLE_COPY_AND_ASSIGN(CommContextManager); diff --git a/paddle/phi/core/distributed/xccl_comm_context.cc b/paddle/phi/core/distributed/xccl_comm_context.cc new file mode 100644 index 0000000000000000000000000000000000000000..6342ff27a487228a8a0e87a0b898e4fec744f5d4 --- /dev/null +++ 
b/paddle/phi/core/distributed/xccl_comm_context.cc @@ -0,0 +1,194 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/xccl_comm_context.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { +namespace distributed { + +XCCLCommContext::XCCLCommContext(const std::string& device_type, + int rank, + int size, + const ccl::CCLRootId& xccl_id) + : CommContext(rank, size) { + device_type_ = device_type; + phi::DeviceManager::CCLCommInitRank(device_type, + size_, + const_cast(&xccl_id), + rank, + &xccl_comm_); +} + +void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root, + const phi::stream::Stream& stream) const { + CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CUSTOM); + if (rank_ == root) { + phi::DeviceManager::CCLBroadcast(device_type_, + const_cast(in_tensor.data()), + in_tensor.numel(), + phi::ccl::ToCCLDataType(in_tensor.dtype()), + root, + xccl_comm_, + stream); + } else { + phi::DeviceManager::CCLBroadcast(device_type_, + out_tensor->data(), + out_tensor->numel(), + phi::ccl::ToCCLDataType(in_tensor.dtype()), + root, + xccl_comm_, + stream); + } +} + +void XCCLCommContext::AllGather(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const phi::stream::Stream& stream) const { + phi::distributed::CommStaticCheck::GatherLikeShape( + *out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CUSTOM); + phi::DeviceManager::CCLAllGather(device_type_, + const_cast(in_tensor.data()), + out_tensor->data(), + in_tensor.numel(), + phi::ccl::ToCCLDataType(in_tensor.dtype()), + xccl_comm_, + stream); +} +void XCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + phi::ccl::CCLReduceOp reduce_type, + const phi::stream::Stream& stream) const { + phi::distributed::CommStaticCheck::ScatterLikeShape( + *out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CUSTOM); + phi::DeviceManager::CCLReduceScatter( + device_type_, + const_cast(in_tensor.data()), + out_tensor->data(), + out_tensor->numel(), + phi::ccl::ToCCLDataType(in_tensor.type()), + reduce_type, + xccl_comm_, + stream); +} + +void XCCLCommContext::Send(const phi::DenseTensor& in_tensor, + const int64_t& count, + const int& peer, + const phi::stream::Stream& stream) const { + phi::distributed::CommStaticCheck::CheckShape( + in_tensor, rank_, size_, phi::AllocationType::CUSTOM); + phi::DeviceManager::CCLSend(device_type_, + const_cast(in_tensor.data()), + count, + phi::ccl::ToCCLDataType(in_tensor.type()), + peer, + 
xccl_comm_, + stream); + VLOG(3) << "rank " << GetRank() << " send " << phi::product(in_tensor.dims()) + << " to " << peer; +} + +void XCCLCommContext::Recv(phi::DenseTensor* out_tensor, + const int64_t& count, + const int& peer, + const phi::stream::Stream& stream) const { + phi::distributed::CommStaticCheck::CheckShape( + *out_tensor, rank_, size_, phi::AllocationType::CUSTOM); + phi::DeviceManager::CCLRecv(device_type_, + out_tensor->data(), + count, + phi::ccl::ToCCLDataType(out_tensor->type()), + peer, + xccl_comm_, + stream); + VLOG(3) << "rank " << GetRank() << " recv " + << phi::product(out_tensor->dims()) << " from " << peer; +} + +void XCCLCommContext::AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + phi::ccl::CCLReduceOp reduce_type, + const phi::stream::Stream& stream) const { + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CUSTOM); + phi::DeviceManager::CCLAllReduce(device_type_, + const_cast(in_tensor.data()), + out_tensor->data(), + in_tensor.numel(), + phi::ccl::ToCCLDataType(in_tensor.type()), + reduce_type, + xccl_comm_, + stream); +} + +void XCCLCommContext::Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + phi::ccl::CCLReduceOp reduce_type, + int root, + const phi::stream::Stream& stream) const { + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ root, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CUSTOM); + phi::DeviceManager::CCLReduce(device_type_, + const_cast(in_tensor.data()), + out_tensor->data(), + in_tensor.numel(), + phi::ccl::ToCCLDataType(in_tensor.type()), + reduce_type, + root, + xccl_comm_, + stream); +} + +void XCCLCommContext::GroupStart() const { + phi::DeviceManager::CCLGroupStart(device_type_); +} +void XCCLCommContext::GroupEnd() const { + phi::DeviceManager::CCLGroupEnd(device_type_); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/xccl_comm_context.h b/paddle/phi/core/distributed/xccl_comm_context.h new file mode 100644 index 0000000000000000000000000000000000000000..f5a51ab332640ff9dd5febca31023ee49c07d069 --- /dev/null +++ b/paddle/phi/core/distributed/xccl_comm_context.h @@ -0,0 +1,83 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once
+
+#include "paddle/phi/core/distributed/comm_context.h"
+#include "paddle/phi/core/macros.h"
+
+#include "paddle/phi/backends/device_manager.h"
+
+namespace phi {
+class DenseTensor;
+namespace distributed {
+
+class XCCLCommContext final : public CommContext {
+ public:
+  XCCLCommContext(const std::string& device_type,
+                  int rank,
+                  int size,
+                  const ccl::CCLRootId& xccl_id);
+
+  ccl::CCLComm GetXcclComm() const { return xccl_comm_; }
+
+  const std::string& GetDeviceType() const { return device_type_; }
+
+  void Broadcast(phi::DenseTensor* out_tensor,
+                 const phi::DenseTensor& in_tensor,
+                 int root,
+                 const phi::stream::Stream& stream) const;
+
+  void Send(const phi::DenseTensor& in_tensor,
+            const int64_t& count,
+            const int& peer,
+            const phi::stream::Stream& stream) const;
+
+  void Recv(phi::DenseTensor* out_tensor,
+            const int64_t& count,
+            const int& peer,
+            const phi::stream::Stream& stream) const;
+
+  void ReduceScatter(phi::DenseTensor* out_tensor,
+                     const phi::DenseTensor& in_tensor,
+                     phi::ccl::CCLReduceOp reduce_type,
+                     const phi::stream::Stream& stream) const;
+
+  void AllGather(phi::DenseTensor* out_tensor,
+                 const phi::DenseTensor& in_tensor,
+                 const phi::stream::Stream& stream) const;
+
+  void AllReduce(phi::DenseTensor* out_tensor,
+                 const phi::DenseTensor& in_tensor,
+                 phi::ccl::CCLReduceOp reduce_type,
+                 const phi::stream::Stream& stream) const;
+
+  void Reduce(phi::DenseTensor* out_tensor,
+              const phi::DenseTensor& in_tensor,
+              phi::ccl::CCLReduceOp reduce_type,
+              int root,
+              const phi::stream::Stream& stream) const;
+
+  void GroupStart() const;
+
+  void GroupEnd() const;
+
+ private:
+  DISABLE_COPY_AND_ASSIGN(XCCLCommContext);
+
+  std::string device_type_;
+  ccl::CCLComm xccl_comm_;
+};
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/all_gather_kernel.cc b/paddle/phi/kernels/cpu/all_gather_kernel.cc
index ff2af4ca5591ef8be8b5d136e19a1e383637d0ae..96433694ffb2b64d9f7705683cc308c84404f90d 100644
--- a/paddle/phi/kernels/cpu/all_gather_kernel.cc
+++ b/paddle/phi/kernels/cpu/all_gather_kernel.cc
@@ -20,6 +20,9 @@
 #if defined(PADDLE_WITH_GLOO)
 #include "paddle/phi/core/distributed/gloo_comm_context.h"
 #endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/phi/core/distributed/xccl_comm_context.h"
+#endif
 
 namespace phi {
 
@@ -49,6 +52,28 @@ void AllGatherKernel(const Context& dev_ctx,
 #endif
 }
 
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+template <typename T, typename Context>
+void AllGatherKernel(const phi::CustomContext& dev_ctx,
+                     const DenseTensor& x,
+                     int nranks,
+                     DenseTensor* out) {
+  auto out_dims = x.dims();
+  out_dims[0] *= nranks;
+  out->Resize(out_dims);
+  dev_ctx.template Alloc<T>(out);
+
+  auto comm_ctx =
+      static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_EQ(
+      nranks,
+      comm_ctx->GetSize(),
+      errors::InvalidArgument(
+          "nranks: %s should be equal to comm size: %s",
+          nranks,
+          comm_ctx->GetSize()));
+
+  comm_ctx->AllGather(out, x, *dev_ctx.GetStream());
+}
+#endif
 }  // namespace phi
 
 PD_REGISTER_KERNEL(all_gather,
@@ -64,3 +89,19 @@ PD_REGISTER_KERNEL(all_gather,
                    int16_t,
                    int64_t,
                    phi::dtype::float16) {}
+
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+PD_REGISTER_KERNEL(all_gather,
+                   Custom,
+                   ALL_LAYOUT,
+                   phi::AllGatherKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/cpu/all_reduce_kernel.cc b/paddle/phi/kernels/cpu/all_reduce_kernel.cc
index 920f900492e13bb68e038d942710c6f9ce4255c5..f6fcaeb2fa7c7f817b6cfb5f1c23ecded156b096 100644
--- a/paddle/phi/kernels/cpu/all_reduce_kernel.cc
+++ b/paddle/phi/kernels/cpu/all_reduce_kernel.cc
@@ -20,6 +20,9 @@
 #if defined(PADDLE_WITH_GLOO)
 #include "paddle/phi/core/distributed/gloo_comm_context.h"
 #endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/phi/core/distributed/xccl_comm_context.h"
+#endif
 
 namespace phi {
 
@@ -47,6 +50,27 @@ void AllReduceKernel(const Context& dev_ctx,
 #endif
 }
 
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+template <typename T, typename Context>
+void AllReduceKernel(const phi::CustomContext& dev_ctx,
+                     const DenseTensor& x,
+                     int reduce_type,
+                     DenseTensor* out) {
+  out->Resize(x.dims());
+  dev_ctx.template Alloc<T>(out);
+
+  auto comm_ctx =
+      static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(
+      comm_ctx,
+      nullptr,
+      errors::Unavailable("XCCLCommContext is nullptr; collective ops should "
+                          "have a ring_id attr."));
+  comm_ctx->AllReduce(
+      out, x, phi::ccl::ToXCCLReduceOp(reduce_type), *dev_ctx.GetStream());
+}
+#endif
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(all_reduce,
@@ -61,3 +85,18 @@ PD_REGISTER_KERNEL(all_reduce,
                    uint8_t,
                    int64_t,
                    phi::dtype::float16) {}
+
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+PD_REGISTER_KERNEL(all_reduce,
+                   Custom,
+                   ALL_LAYOUT,
+                   phi::AllReduceKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/cpu/all_to_all_kernel.cc b/paddle/phi/kernels/cpu/all_to_all_kernel.cc
index 6347d1e3983fb41cc0c36c18104ffcff971a731a..5ea535c6e5e90119259c26d47fe000a26adc0ad3 100644
--- a/paddle/phi/kernels/cpu/all_to_all_kernel.cc
+++ b/paddle/phi/kernels/cpu/all_to_all_kernel.cc
@@ -16,6 +16,9 @@
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/phi/core/distributed/xccl_comm_context.h"
+#endif
 
 namespace phi {
 
@@ -26,6 +29,42 @@ void AllToAllKernel(const Context& dev_ctx UNUSED,
   PADDLE_THROW(
       errors::Unimplemented("Unimplemented cpu kernel for all_to_all."));
 }
+
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+template <typename T, typename Context>
+void AllToAllKernel(const phi::CustomContext& dev_ctx,
+                    const DenseTensor& x,
+                    DenseTensor* out) {
+  out->Resize(x.dims());
+  dev_ctx.template Alloc<T>(out);
+  auto comm_ctx =
+      static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
+
+  int nranks = comm_ctx->GetSize();
+  int rank = comm_ctx->GetRank();
+  int send_numel = x.numel() / nranks;
+
+  std::vector<void*> sendbuf, recvbuf;
+  std::vector<size_t> sendsize(nranks, send_numel);
+  std::vector<phi::ccl::CCLDataType> sendtype(
+      nranks, phi::ccl::ToCCLDataType(x.dtype()));
+  for (auto i = 0; i < nranks; ++i) {
+    sendbuf.push_back(const_cast<T*>(x.data<T>()) + i * send_numel);
+    recvbuf.push_back(out->data<T>() + i * send_numel);
+  }
+  phi::DeviceManager::CCLAllToAll(dev_ctx.GetPlace().GetDeviceType(),
+                                  const_cast<const void**>(sendbuf.data()),
+                                  sendsize.data(),
+                                  sendtype.data(),
+                                  recvbuf.data(),
+                                  sendsize.data(),
+                                  sendtype.data(),
+                                  rank,
+                                  nranks,
+                                  comm_ctx->GetXcclComm(),
+                                  *dev_ctx.GetStream());
+}
+
+#endif
 }  // namespace phi
 
@@ -41,3 +80,17 @@ PD_REGISTER_KERNEL(all_to_all,
                    uint8_t,
                    int64_t,
                    phi::dtype::float16) {}
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+PD_REGISTER_KERNEL(all_to_all,
+                   Custom,
+                   ALL_LAYOUT,
+                   phi::AllToAllKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/cpu/reduce_kernel.cc b/paddle/phi/kernels/cpu/reduce_kernel.cc
index d4650733f49830e5372d4da266981cad6293f39f..83b3fd99c124e184caafcb3fde8a9c5fb73fc12e 100644
--- a/paddle/phi/kernels/cpu/reduce_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_kernel.cc
@@ -20,6 +20,9 @@
 #if defined(PADDLE_WITH_GLOO)
 #include "paddle/phi/core/distributed/gloo_comm_context.h"
 #endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/phi/core/distributed/xccl_comm_context.h"
+#endif
 
 namespace phi {
 
@@ -51,6 +54,35 @@ void ReduceKernel(const Context& dev_ctx,
 #endif
 }
 
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+template <typename T, typename Context>
+void ReduceKernel(const phi::CustomContext& dev_ctx,
+                  const DenseTensor& x,
+                  int root,
+                  int reduce_type,
+                  DenseTensor* out) {
+  PADDLE_ENFORCE_GT(
+      x.numel(),
+      0,
+      phi::errors::InvalidArgument(
+          "The tensor to be reduced must not be empty."));
+  out->Resize(x.dims());
+  dev_ctx.template Alloc<T>(out);
+
+  auto comm_ctx =
+      static_cast<distributed::XCCLCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(
+      comm_ctx,
+      nullptr,
+      errors::Unavailable("XCCLCommContext is nullptr; collective ops should "
+                          "have a ring_id attr."));
+  comm_ctx->Reduce(out,
+                   x,
+                   phi::ccl::ToXCCLReduceOp(reduce_type),
+                   root,
+                   *dev_ctx.GetStream());
+}
+#endif
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(reduce,
@@ -65,3 +97,18 @@ PD_REGISTER_KERNEL(reduce,
                    uint8_t,
                    int64_t,
                    phi::dtype::float16) {}
+
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+PD_REGISTER_KERNEL(reduce,
+                   Custom,
+                   ALL_LAYOUT,
+                   phi::ReduceKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::float16) {}
+#endif
diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index e27557df6af9c85b5e1838d2520b2d190813a120..b1d61196e056d93367a78f4a9d996efa6e1febed 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -334,3 +334,9 @@ def _init_parallel_env(backend):
         core.CommContextManager.create_nccl_comm_context(
             store, "0", rank, world_size
         )
+    elif backend == "xccl":
+        dev_type = global_env.device_type
+        paddle.device.set_device(f"{dev_type}:{dev_id}")
+        core.CommContextManager.create_xccl_comm_context(
+            store, "0", dev_type, rank, world_size
+        )
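
Usage sketch (illustrative, not part of the patch): rank 0 generates the XCCL root id inside CreateXCCLCommContext and publishes it through the store under "XCCLCommContext/<key>", the remaining ranks fetch it, and the cached XCCLCommContext then drives collectives. The device type "custom_cpu" and the ready-made store below are assumptions for illustration only:

// Minimal sketch, assuming a registered custom device plugin (hypothetically
// named "custom_cpu") and an already-initialized store shared by all ranks.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"

void AllReduceSumOnCustomDevice(
    const std::shared_ptr<phi::distributed::Store>& store,
    int rank,
    int world_size,
    const phi::DenseTensor& in,
    phi::DenseTensor* out,
    const phi::stream::Stream& stream) {
  // Rank 0 creates the root id and sets it in the store; the other ranks
  // block on store->get() inside CreateXCCLCommContext until it appears.
  phi::distributed::CommContextManager::CreateXCCLCommContext(
      store, /*unique_comm_key=*/"0", /*device_type=*/"custom_cpu",
      rank, world_size);

  // The manager caches the context under the same key.
  auto* comm_ctx = static_cast<phi::distributed::XCCLCommContext*>(
      phi::distributed::CommContextManager::GetInstance().Get("0"));

  // SUM all-reduce issued on the caller-provided device stream.
  comm_ctx->AllReduce(out, in, phi::ccl::CCLReduceOp::SUM, stream);
}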
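
In the same illustrative spirit, the ToXCCLReduceOp helper added to c_comm_lib.h is what the new Custom-device kernels use to translate the integer reduce_type attribute (a phi::ReduceType value) into a CCL reduce op; a minimal sketch:

#include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/common/reduce_type.h"

// kRedMax travels through op attributes as a plain int; the helper maps it
// back onto the backend-agnostic enum (CCLReduceOp::MAX here) and throws
// errors::Unavailable for values outside SUM/MAX/MIN/PRODUCT/AVG.
phi::ccl::CCLReduceOp op =
    phi::ccl::ToXCCLReduceOp(static_cast<int>(phi::ReduceType::kRedMax));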