From d03bbefa96ee41100a3e1d9a26e6986845bfc918 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Fri, 12 May 2023 14:43:38 +0800 Subject: [PATCH] [CustomDevice] add inference MP support, PART0 (#53719) * [CustomDevice] add inference MP support, PART0 * update --- paddle/fluid/imperative/CMakeLists.txt | 20 +- paddle/fluid/imperative/xccl_context.cc | 282 +++++++++++++++++++ paddle/fluid/imperative/xccl_context.h | 71 +++++ paddle/fluid/platform/collective_helper.cc | 230 +++++++++++++++ paddle/fluid/platform/collective_helper.h | 108 +++++++ paddle/fluid/platform/gen_comm_id_helper.cc | 60 +++- paddle/fluid/platform/gen_comm_id_helper.h | 2 +- paddle/phi/backends/c_comm_lib.h | 4 +- paddle/phi/backends/custom/custom_context.cc | 13 + paddle/phi/backends/custom/custom_context.h | 7 + 10 files changed, 791 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/imperative/xccl_context.cc create mode 100644 paddle/fluid/imperative/xccl_context.h diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 8ef07cfa76e..f6fe845b30c 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -123,10 +123,26 @@ if(NOT WIN32) SRCS reducer.cc DEPS layer) endif() - + if(WITH_CUSTOM_DEVICE) + cc_library( + xccl_context + SRCS xccl_context.cc + DEPS collective_helper device_context tensor var_type_traits) + if(NOT + (WITH_NCCL + OR WITH_RCCL + OR WITH_XPU_BKCL + OR WITH_GLOO)) + cc_library( + reducer + SRCS reducer.cc + DEPS layer) + endif() + endif() if(WITH_NCCL OR WITH_RCCL - OR WITH_XPU_BKCL) + OR WITH_XPU_BKCL + OR WITH_CUSTOM_DEVICE) cc_library( heter_ccl_context SRCS heter_ccl_context.cc diff --git a/paddle/fluid/imperative/xccl_context.cc b/paddle/fluid/imperative/xccl_context.cc new file mode 100644 index 00000000000..dc7a4b939b7 --- /dev/null +++ b/paddle/fluid/imperative/xccl_context.cc @@ -0,0 +1,282 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/fluid/imperative/xccl_context.h"
+
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/gen_comm_id_helper.h"
+#endif
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace framework {
+class Variable;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace imperative {
+
+static void XcclAllReduce(const phi::DenseTensor &src,
+                          phi::DenseTensor *dst,
+                          const phi::stream::Stream &stream,
+                          const phi::ccl::CCLComm &comm) {
+  const auto &place = src.place();
+  PADDLE_ENFORCE_EQ(
+      platform::is_custom_place(place),
+      true,
+      platform::errors::Unimplemented(
+          "Dynamic graph mode does not support multi-CPU training yet."));
+
+  void *src_ptr = const_cast<void *>(src.data());
+  dst->Resize(src.dims());
+  auto *dst_ptr = phi::DeviceContextPool::Instance()
+                      .Get(src.place())
+                      ->Alloc(dst, src.dtype());
+  auto xccl_dtype = phi::ccl::ToCCLDataType(src.dtype());
+
+  phi::DeviceManager::CCLAllReduce(place.GetDeviceType(),
+                                   src_ptr,
+                                   dst_ptr,
+                                   src.numel(),
+                                   xccl_dtype,
+                                   phi::ccl::CCLReduceOp::SUM,
+                                   comm,
+                                   stream);
+}
+
+void XCCLParallelContext::BcastXCCLId(
+    std::vector<phi::ccl::CCLRootId> &xccl_ids,  // NOLINT
+    int root,
+    int server_fd) {
+  if (strategy_.local_rank_ == root) {
+    std::vector<std::string> other_trainers;
+    for (auto &ep : strategy_.trainer_endpoints_) {
+      if (ep != strategy_.current_endpoint_) {
+        other_trainers.push_back(ep);
+      }
+    }
+    platform::SendBroadCastCommID(other_trainers, &xccl_ids);
+  } else {
+    platform::RecvBroadCastCommID(
+        server_fd, strategy_.current_endpoint_, &xccl_ids);
+  }
+}
+
+void XCCLParallelContext::Init() {
+  int server_fd = -1;
+
+  std::vector<phi::ccl::CCLRootId> xccl_ids;
+  xccl_ids.resize(strategy_.nrings_);
+
+  if (strategy_.local_rank_ == 0) {
+    // generate the unique xccl id on the root worker
+    for (size_t i = 0; i < xccl_ids.size(); ++i) {
+      phi::DeviceManager::CCLGetUniqueId(place_.GetDeviceType(), &xccl_ids[i]);
+    }
+  } else {
+    // FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
+    // on rank0.
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
+  }
+  BcastXCCLId(xccl_ids, 0, server_fd);
+
+  int dev_id = place_.device;
+  for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
+    VLOG(0) << "init xccl context nranks: " << strategy_.nranks_
+            << " local rank: " << strategy_.local_rank_ << " dev id: " << dev_id
+            << " ring id: " << ring_id;
+    // it will assign xccl_comm in phi::CustomContext within ring_id
+    platform::XCCLCommContext::Instance(place_.GetDeviceType())
+        .CreateComm(&xccl_ids[ring_id],
+                    strategy_.nranks_,
+                    strategy_.local_rank_,
+                    dev_id,
+                    ring_id);
+    auto compute_event = std::make_shared<phi::event::Event>();
+    auto comm_event = std::make_shared<phi::event::Event>();
+    compute_event->Init(place_);
+    comm_event->Init(place_);
+    compute_events_.emplace_back(compute_event);
+    comm_events_.emplace_back(comm_event);
+  }
+}
+
+void XCCLParallelContext::InitWithRingID(int ring_id) {
+  int server_fd = -1;
+  std::vector<phi::ccl::CCLRootId> xccl_ids;
+  xccl_ids.resize(1);
+
+  if (strategy_.local_rank_ == 0) {
+    // generate the unique xccl id on the root worker
+    phi::DeviceManager::CCLGetUniqueId(place_.GetDeviceType(), &xccl_ids[0]);
+  } else {
+    // FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
+    // on rank0.
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
+  }
+  BcastXCCLId(xccl_ids, 0, server_fd);
+
+  int dev_id = place_.device;
+  VLOG(0) << "init xccl context nranks: " << strategy_.nranks_
+          << " local rank: " << strategy_.local_rank_ << " dev id: " << dev_id
+          << " ring id: " << ring_id;
+  // it will assign xccl_comm in phi::CustomContext within ring_id
+  platform::XCCLCommContext::Instance(place_.GetDeviceType())
+      .CreateComm(&xccl_ids[0],
+                  strategy_.nranks_,
+                  strategy_.local_rank_,
+                  dev_id,
+                  ring_id);
+
+  auto compute_event = std::make_shared<phi::event::Event>();
+  auto comm_event = std::make_shared<phi::event::Event>();
+  compute_event->Init(place_);
+  comm_event->Init(place_);
+  compute_events_.emplace_back(compute_event);
+  comm_events_.emplace_back(comm_event);
+}
+
+void XCCLParallelContext::AllReduceByStream(const framework::Variable &src,
+                                            framework::Variable *dst,
+                                            int ring_id,
+                                            bool use_calc_stream) {
+  PADDLE_ENFORCE_EQ(
+      platform::is_custom_place(place_),
+      true,
+      platform::errors::Unimplemented(
+          "Dynamic graph mode does not support multi-CPU training yet."));
+  auto place = place_;
+
+  auto *dev_ctx = static_cast<platform::CustomDeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(place));
+  platform::XCCLComm *comm =
+      platform::XCCLCommContext::Instance(place.GetDeviceType())
+          .Get(ring_id, place);
+  auto stream = use_calc_stream ? dev_ctx->GetStream() : comm->stream();
+
+  if (src.IsType<phi::DenseTensor>()) {
+    if (!dst->IsType<phi::DenseTensor>()) {
+      dst->Clear();
+    }
+    XcclAllReduce(src.Get<phi::DenseTensor>(),
+                  dst->GetMutable<phi::DenseTensor>(),
+                  *stream,
+                  comm->comm());
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "custom device unsupported variable type %s for imperative allreduce, "
+        "only "
+        "LoDTensor are supported.",
+        platform::demangle(framework::ToTypeName(src.Type()))));
+  }
+}
+
+void XCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
+  VLOG(3) << "/// DEBUG /// start inter broadcast with ring_id: " << ring_id;
+  phi::DenseTensor *src_tensor = src->GetMutable<phi::DenseTensor>();
+  const auto &place = src_tensor->place();
+  platform::XCCLComm *comm =
+      platform::XCCLCommContext::Instance(place_.GetDeviceType())
+          .Get(ring_id, place);
+  auto stream = comm->stream();
+
+  void *src_ptr = src_tensor->data();
+  auto xccl_dtype = phi::ccl::ToCCLDataType(src_tensor->dtype());
+
+  phi::DeviceManager::CCLBroadcast(place_.GetDeviceType(),
+                                   src_ptr,
+                                   src_tensor->numel(),
+                                   xccl_dtype,
+                                   0,
+                                   comm->comm(),
+                                   *stream);
+}
+
+paddle::platform::DeviceContext *XCCLParallelContext::GetDeviceContext(
+    int ring_id) {
+  return static_cast<paddle::platform::DeviceContext *>(
+      platform::XCCLCommContext::Instance(place_.GetDeviceType())
+          .Get(ring_id, place_)
+          ->dev_context());
+}
+
+void XCCLParallelContext::WaitCompute(int ring_id) {
+  PADDLE_ENFORCE_GE(
+      ring_id,
+      0,
+      platform::errors::OutOfRange("ring id must >= 0, but got %d", ring_id));
+  PADDLE_ENFORCE_LT(ring_id,
+                    compute_events_.size(),
+                    platform::errors::OutOfRange(
+                        "ring id must < compute events size,"
+                        "but got ring id = %d, compute events size = %d",
+                        ring_id,
+                        compute_events_.size()));
+
+  auto compute_stream = static_cast<platform::CustomDeviceContext *>(
+                            platform::DeviceContextPool::Instance().Get(place_))
+                            ->GetStream();
+  auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
+                         .Get(ring_id, place_)
+                         ->stream();
+  auto event = compute_events_[ring_id].get();
+
+  // compute_stream-->event-->comm_stream
+  event->Record(compute_stream.get());
+  comm_stream->WaitEvent(event);
+}
+
+void XCCLParallelContext::WaitComm(int ring_id) {
+  PADDLE_ENFORCE_GE(
+      ring_id,
+      0,
+      platform::errors::OutOfRange("ring id must >= 0, but got %d", ring_id));
+  PADDLE_ENFORCE_LT(ring_id,
+                    comm_events_.size(),
+                    platform::errors::OutOfRange(
+                        "ring id must < comm events size,"
+                        "but got ring id = %d, comm events size = %d",
+                        ring_id,
+                        comm_events_.size()));
+
+  auto compute_stream = static_cast<platform::CustomDeviceContext *>(
+                            platform::DeviceContextPool::Instance().Get(place_))
+                            ->GetStream();
+  auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
+                         .Get(ring_id, place_)
+                         ->stream();
+  auto event = comm_events_[ring_id].get();
+
+  // comm_stream-->event-->compute_stream
+  event->Record(comm_stream.get());
+  compute_stream->WaitEvent(event);
+}
+
+void XCCLParallelContext::SynchronizeCompute() {
+  auto *compute_dev_ctx = static_cast<platform::CustomDeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(place_));
+  compute_dev_ctx->Wait();
+}
+
+}  // namespace imperative
+}  // namespace paddle
diff --git a/paddle/fluid/imperative/xccl_context.h b/paddle/fluid/imperative/xccl_context.h
new file mode 100644
index 00000000000..4426a253725
--- /dev/null
+++ b/paddle/fluid/imperative/xccl_context.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/imperative/parallel_context.h"
+
+namespace paddle {
+namespace framework {
+class Variable;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace imperative {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+class XCCLParallelContext : public ParallelContext {
+ public:
+  explicit XCCLParallelContext(const ParallelStrategy& strategy,
+                               const platform::Place& place)
+      : ParallelContext(strategy, place) {}
+
+  ~XCCLParallelContext() override = default;
+
+  void BcastXCCLId(std::vector<phi::ccl::CCLRootId>& xccl_ids,  // NOLINT
+                   int root,
+                   int server_fd);
+
+  void Init() override;
+
+  void InitWithRingID(int ring_id) override;
+
+  void AllReduceByStream(const framework::Variable& src,
+                         framework::Variable* dst,
+                         int ring_id,
+                         bool use_calc_stream) override;
+
+  void Broadcast(framework::Variable* src, int ring_id) override;
+
+  paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override;
+
+  void WaitCompute(int ring_id) override;
+
+  void WaitComm(int ring_id) override;
+
+  void SynchronizeCompute() override;
+
+ private:
+  // used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
+  std::vector<std::shared_ptr<phi::event::Event>> compute_events_;
+
+  // used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
+  std::vector<std::shared_ptr<phi::event::Event>> comm_events_;
+};
+#endif
+}  // namespace imperative
+}  // namespace paddle
diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc
index 948fa300c1f..b133a57d523 100644
--- a/paddle/fluid/platform/collective_helper.cc
+++ b/paddle/fluid/platform/collective_helper.cc
@@ -404,5 +404,235 @@ void BKCLCommContext::ReleaseBKCLComms() {
 
 #endif
 
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+class XCCLCommImpl : public XCCLComm {
+ public:
+  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
+  int ring_id() const override { return ring_id_; }
+
+  void set_nranks(int nranks) { nranks_ = nranks; }
+  int nranks() const override { return nranks_; }
+
+  void set_rank(int rank) { rank_ = rank; }
+  int rank() const override { return rank_; }
+
+  int device_id() const override { return dev_ctx_->GetPlace().device; }
+
+  void set_comm(phi::ccl::CCLComm comm) { comm_ = comm; }
+  phi::ccl::CCLComm comm() const override { return comm_; }
+
+  std::shared_ptr<phi::stream::Stream> stream() const override {
+    return dev_ctx_->GetStream();
+  }
+
+  void set_dev_ctx(std::unique_ptr<phi::CustomContext>&& dev_ctx) {
+    dev_ctx_ = std::move(dev_ctx);
+  }
+  phi::CustomContext* dev_context() const override { return dev_ctx_.get(); }
+
+  std::shared_ptr<phi::event::Event> compute_event() const override {
+    return compute_event_;
+  }
+
+  std::shared_ptr<phi::event::Event> comm_event() const override {
+    return comm_event_;
+  }
+
+  void set_compute_event(std::shared_ptr<phi::event::Event>&& compute_event) {
+    compute_event_ = std::move(compute_event);
+  }
+
+  void set_comm_event(std::shared_ptr<phi::event::Event>&& comm_event) {
+    comm_event_ = std::move(comm_event);
+  }
+
+ private:
+  int ring_id_;
+  int nranks_;
+  int rank_;
+  phi::ccl::CCLComm comm_;
+  std::unique_ptr<phi::CustomContext> dev_ctx_;
+
+  // used for comm wait compute, compute_stream-->event-->comm_stream
+  std::shared_ptr<phi::event::Event> compute_event_;
+
+  // used for compute wait comm, comm_stream-->event-->compute_stream
+  std::shared_ptr<phi::event::Event> comm_event_;
+};
+
+static std::unordered_map<std::string, std::unique_ptr<XCCLCommContext>>
+    g_xccl_comm_ctx_map;
+
+void XCCLCommContext::Release() {
+  for (auto& it : g_xccl_comm_ctx_map) {
+    it.second->ReleaseXCCLComms();
+  }
+  g_xccl_comm_ctx_map.clear();
+}
+
+XCCLCommContext& XCCLCommContext::Instance(const std::string& device_type) {
+  if (g_xccl_comm_ctx_map.find(device_type) == g_xccl_comm_ctx_map.end()) {
+    g_xccl_comm_ctx_map.insert(
+        {device_type,
+         std::unique_ptr<XCCLCommContext>(new XCCLCommContext(device_type))});
+  }
+  return *g_xccl_comm_ctx_map[device_type];
+}
+
+XCCLComm* XCCLCommContext::CreateComm(phi::ccl::CCLRootId* xccl_id,
+                                      int nranks,
+                                      int rank,
+                                      int dev_id,
+                                      int ring_id) {
+  PADDLE_ENFORCE_NOT_NULL(xccl_id,
+                          platform::errors::InvalidArgument(
+                              "The xccl unique id should not be null."));
+  PADDLE_ENFORCE_GT(
+      nranks,
+      1,
+      platform::errors::InvalidArgument(
+          "Expected nranks > 1. But received nranks is %d.", nranks));
+  PADDLE_ENFORCE_GE(rank,
+                    0,
+                    platform::errors::InvalidArgument(
+                        "Expected rank >= 0. But received rank is %d.", rank));
+  PADDLE_ENFORCE_LT(
+      rank,
+      nranks,
+      platform::errors::InvalidArgument(
+          "Expected rank < nranks. But received rank is %d, nranks is %d.",
+          rank,
+          nranks));
+  PADDLE_ENFORCE_GE(
+      dev_id,
+      0,
+      platform::errors::InvalidArgument(
+          "Expected dev_id >= 0. But received dev_id is %d.", dev_id));
+
+  phi::ccl::CCLComm comm = nullptr;
+  phi::DeviceManager::SetDevice(device_type_, dev_id);
+  phi::DeviceManager::CCLCommInitRank(
+      device_type_, nranks, xccl_id, rank, &comm);
+
+  auto* comm_wrapper = AssignXCCLComm(comm, nranks, rank, dev_id, ring_id);
+
+  VLOG(1) << "xccl communicator of rank " << rank << " in ring " << ring_id
+          << " has been created on device " << dev_id;
+
+  return comm_wrapper;
+}
+
+void XCCLCommContext::CreateXCCLCommMultiTrainer(
+    const std::vector<int>& dev_ids,
+    phi::ccl::CCLRootId* xccl_id,
+    int ntrainers,
+    int train_id,
+    int ring_id) {
+  PADDLE_ENFORCE_GT(
+      dev_ids.size(),
+      0,
+      paddle::platform::errors::InvalidArgument(
+          "dev ids = [%d], it should be greater than 0.", dev_ids.size()));
+  const int kDevices = dev_ids.size();
+  VLOG(1) << "Begin CreateXCCLCommMultiTrainer. device number: " << kDevices
+          << ", ntrainers: " << ntrainers << ", train_id: " << train_id
+          << ", ring_id: " << ring_id;
+  phi::ccl::CCLComm comms[kDevices];
+  {
+    for (int i = 0; i < kDevices; i++) {
+      phi::DeviceManager::SetDevice(device_type_, i);
+      phi::DeviceManager::CCLCommInitRank(device_type_,
+                                          kDevices * ntrainers,
+                                          xccl_id,
+                                          train_id * kDevices + i,
+                                          comms + i);
+      VLOG(1) << "CCLCommInitRank: " << i;
+    }
+  }
+  PADDLE_ENFORCE_EQ(comm_map_.count(ring_id),
+                    0,
+                    platform::errors::InvalidArgument(
+                        "comm_map_ of ring_id: %s should be 0. %s is provided",
+                        ring_id,
+                        comm_map_.count(ring_id)));
+  for (int i = 0; i < kDevices; ++i) {
+    AssignXCCLComm(comms[i],
+                   kDevices * ntrainers,
+                   train_id * kDevices + i,
+                   dev_ids[i],
+                   ring_id);
+    VLOG(1) << "xccl communicator of train_id " << train_id * kDevices + i
+            << " in ring " << ring_id << " has been created on device "
+            << dev_ids[i];
+  }
+}
+
+XCCLComm* XCCLCommContext::AssignXCCLComm(
+    phi::ccl::CCLComm comm, int nranks, int rank, int dev_id, int ring_id) {
+  auto place = CustomPlace(device_type_, dev_id);
+  std::unique_ptr<phi::CustomContext> dev_ctx(new phi::CustomContext(place));
+  dev_ctx->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
+                            .GetAllocator(place)
+                            .get());
+  dev_ctx->SetHostAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(paddle::platform::CPUPlace())
+          .get());
+  dev_ctx->SetZeroAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetZeroAllocator(place)
+          .get());
+  dev_ctx->SetHostZeroAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetZeroAllocator(paddle::platform::CPUPlace())
+          .get());
+  dev_ctx->SetPinnedAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(paddle::platform::CPUPlace())
+          .get());
+  // dev_ctx->PartialInitWithAllocator();
+
+  auto compute_event = std::make_shared<phi::event::Event>();
+  auto comm_event = std::make_shared<phi::event::Event>();
+  compute_event->Init(place);
+  comm_event->Init(place);
+
+  auto* c = new XCCLCommImpl;
+  c->set_ring_id(ring_id);
+  c->set_nranks(nranks);
+  c->set_rank(rank);
+  c->set_comm(comm);
+  c->set_dev_ctx(std::move(dev_ctx));
+  c->set_compute_event(std::move(compute_event));
+  c->set_comm_event(std::move(comm_event));
+
+  comm_map_mutex_.lock();
+  if (comm_map_.count(ring_id) == 0) {
+    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<XCCLComm>>());
+  }
+  auto& dev2comm = comm_map_[ring_id];
+
+  dev2comm.emplace(dev_id, std::unique_ptr<XCCLComm>(c));
+  comm_map_mutex_.unlock();
+
+  if (ring_id == 0) {
+    auto* dev_ctx = static_cast<phi::CustomContext*>(
+        platform::DeviceContextPool::Instance().Get(place));
+    dev_ctx->set_xccl_comm(comm);
+  }
+  VLOG(4) << "add xccl comm: " << comm_map_[ring_id][dev_id].get()
+          << ", ring_id:" << ring_id << ", dev_id:" << dev_id;
+  return comm_map_[ring_id][dev_id].get();
+}
+
+void XCCLCommContext::ReleaseXCCLComms() {
+  for (auto& p : comm_map_) {
+    for (auto& q : p.second) {
+      q.second.reset();
+    }
+  }
+}
+
+#endif
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h
index c0a7f2f37b0..6636856a0eb 100644
--- a/paddle/fluid/platform/collective_helper.h
+++ b/paddle/fluid/platform/collective_helper.h
@@ -22,6 +22,7 @@
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/backends/device_manager.h"
 #include "paddle/utils/variant.h"
 
 namespace paddle {
@@ -243,5 +244,112 @@ class BKCLCommContext {
   DISABLE_COPY_AND_ASSIGN(BKCLCommContext);
 };
 #endif
+
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+class XCCLComm {
+ public:
+  virtual int ring_id() const = 0;
+  virtual int nranks() const = 0;
+  virtual int rank() const = 0;
+  virtual int device_id() const = 0;
+  virtual phi::ccl::CCLComm comm() const = 0;
+  virtual std::shared_ptr<phi::stream::Stream> stream() const = 0;
+  virtual std::shared_ptr<phi::event::Event> compute_event() const = 0;
+  virtual std::shared_ptr<phi::event::Event> comm_event() const = 0;
+  virtual phi::CustomContext* dev_context() const = 0;
+  virtual ~XCCLComm() = default;
+};
+
+// A singleton XCCL communicator context reserves communication ring ids
+class XCCLCommContext {
+ public:
+  static XCCLCommContext& Instance(const std::string& device_type);
+  static void Release();
+
+  XCCLComm* CreateComm(phi::ccl::CCLRootId* xccl_id,
+                       int nranks,
+                       int rank,
+                       int dev_id,
+                       int ring_id = 0);
+
+  void CreateAllXCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
+
+  void CreateXCCLCommMultiTrainer(const std::vector<int>& dev_ids,
+                                  phi::ccl::CCLRootId* xccl_id,
+                                  int nranks,
+                                  int rank,
+                                  int ring_id);
+
+  // a latter comm with the same dev_id and the same ring_id
+  // will override the former
+  XCCLComm* AssignXCCLComm(phi::ccl::CCLComm comm,
+                           int nranks,
+                           int rank,
+                           int dev_id,
+                           int ring_id = 0);
+
+  // retrieve a communicator by the ring id in multiprocessing mode
+  XCCLComm* Get(int ring_id) const {
+    PADDLE_ENFORCE_GT(
+        comm_map_.count(ring_id),
+        0,
+        platform::errors::InvalidArgument(
+            "Communicator in ring id %d has not been initialized.", ring_id));
+    PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(),
+                      1,
+                      platform::errors::InvalidArgument(
+                          "One device id should be specified to retrieve from "
+                          "multiple communicators."));
+    return comm_map_.at(ring_id).begin()->second.get();
+  }
+
+  int GetRingId(phi::ccl::CCLComm comm) const {
+    for (const auto& pair : comm_map_) {
+      for (const auto& p : pair.second) {
+        if (p.second.get()->comm() == comm) {
+          return pair.first;
+        }
+      }
+    }
+    return -1;
+  }
+
+  // retrieve a communicator by the ring id and the device id
+  XCCLComm* Get(int ring_id, int dev_id) const {
+    PADDLE_ENFORCE_GT(
+        comm_map_.count(ring_id),
+        0,
+        platform::errors::InvalidArgument(
+            "Communicator of ring id %d has not been initialized.", ring_id));
+    PADDLE_ENFORCE_GT(
+        comm_map_.at(ring_id).count(dev_id),
+        0,
+        platform::errors::InvalidArgument(
+            "Communicator at device id %d has not been initialized in ring %d.",
+            dev_id,
+            ring_id));
+    return comm_map_.at(ring_id).at(dev_id).get();
+  }
+
+  // retrieve a communicator by the ring id and place
+  XCCLComm* Get(int ring_id, Place place) const {
+    return Get(ring_id, place.device);
+  }
+
+ private:
+  std::string device_type_;
+  std::once_flag once_flag_;
+  std::mutex comm_map_mutex_;
+  // ring id to dev-XCCLComm
+  std::map<int, std::map<int, std::unique_ptr<XCCLComm>>> comm_map_;
+
+  void ReleaseXCCLComms();
+
+  XCCLCommContext() = default;
+  explicit XCCLCommContext(const std::string& device_type)
+      : device_type_(device_type) {}
+  DISABLE_COPY_AND_ASSIGN(XCCLCommContext);
+};
+#endif
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc
index ca9f9d7c4f8..0237d28e52c 100644
--- a/paddle/fluid/platform/gen_comm_id_helper.cc
+++ b/paddle/fluid/platform/gen_comm_id_helper.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
 #include "paddle/fluid/platform/gen_comm_id_helper.h"
 
 #include <arpa/inet.h>
@@ -33,6 +33,9 @@ limitations under the License. */
 #if defined(PADDLE_WITH_XPU_BKCL)
 #include "xpu/bkcl.h"
 #endif
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+#include "paddle/phi/backends/c_comm_lib.h"
+#endif
 
 PHI_DECLARE_int32(get_host_by_name_time);
 
@@ -348,6 +351,58 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
                  "send comm unique id");
 }
 
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+template <>
+void RecvCommID(int conn, phi::ccl::CCLRootId* nccl_id) {
+  char buffer[MAX_COMMUNIQUEID_LEN] = {0};
+  CHECK_SYS_CALL(SocketRecv(conn, buffer, sizeof(size_t)),
+                 "recv comm unique id size");
+  size_t unique_id_size = *reinterpret_cast<size_t*>(buffer);
+  VLOG(6) << "RecvCommID size: " << unique_id_size;
+  nccl_id->resize(unique_id_size);
+
+  size_t n_repeat = unique_id_size / MAX_COMMUNIQUEID_LEN;
+  size_t n_remain = unique_id_size % MAX_COMMUNIQUEID_LEN;
+  for (size_t i = 0; i < n_repeat; ++i) {
+    CHECK_SYS_CALL(SocketRecv(conn, buffer, MAX_COMMUNIQUEID_LEN),
+                   "recv comm unique id");
+    memcpy(nccl_id->data() + i * MAX_COMMUNIQUEID_LEN,
+           buffer,
+           MAX_COMMUNIQUEID_LEN);
+  }
+  if (n_remain) {
+    CHECK_SYS_CALL(SocketRecv(conn, buffer, n_remain), "recv comm unique id");
+    memcpy(nccl_id->data() + n_repeat * MAX_COMMUNIQUEID_LEN, buffer, n_remain);
+  }
+  VLOG(6) << "RecvCommID done";
+}
+
+template <>
+void SendCommID(int conn, phi::ccl::CCLRootId* nccl_id) {
+  char buffer[MAX_COMMUNIQUEID_LEN] = {0};
+  size_t unique_id_size = nccl_id->size();
+  VLOG(6) << "SendCommID size: " << unique_id_size;
+  memcpy(buffer, &unique_id_size, sizeof(size_t));
+  CHECK_SYS_CALL(SocketSend(conn, buffer, sizeof(size_t)),
+                 "send comm unique id size");
+
+  size_t n_repeat = unique_id_size / MAX_COMMUNIQUEID_LEN;
+  size_t n_remain = unique_id_size % MAX_COMMUNIQUEID_LEN;
+  for (size_t i = 0; i < n_repeat; ++i) {
+    memcpy(buffer,
+           nccl_id->data() + i * MAX_COMMUNIQUEID_LEN,
+           MAX_COMMUNIQUEID_LEN);
+    CHECK_SYS_CALL(SocketSend(conn, buffer, MAX_COMMUNIQUEID_LEN),
+                   "send comm unique id");
+  }
+  if (n_remain) {
+    memcpy(buffer, nccl_id->data() + n_repeat * MAX_COMMUNIQUEID_LEN, n_remain);
+    CHECK_SYS_CALL(SocketSend(conn, buffer, n_remain), "send comm unique id");
+  }
+  VLOG(6) << "SendCommID done";
+}
+#endif
+
 template <typename CommUniqueId>
 void SendBroadCastCommID(std::vector<std::string> servers,
                          std::vector<CommUniqueId>* nccl_ids,
@@ -444,6 +499,9 @@ INSTANT_TEMPLATE(ncclUniqueId)
 #ifdef PADDLE_WITH_XPU_BKCL
 INSTANT_TEMPLATE(BKCLUniqueId)
 #endif
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+INSTANT_TEMPLATE(phi::ccl::CCLRootId)
+#endif
 
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/gen_comm_id_helper.h b/paddle/fluid/platform/gen_comm_id_helper.h
index 0aef760fd4e..d97b4131199 100644
--- a/paddle/fluid/platform/gen_comm_id_helper.h
+++ b/paddle/fluid/platform/gen_comm_id_helper.h
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
 #include <functional>
 #include <memory>
 #include <mutex>
diff --git a/paddle/phi/backends/c_comm_lib.h b/paddle/phi/backends/c_comm_lib.h
index c93a1c1d18d..e67530add58 100644
--- a/paddle/phi/backends/c_comm_lib.h
+++ b/paddle/phi/backends/c_comm_lib.h
@@ -23,8 +23,8 @@
 namespace phi {
 namespace ccl {
-using CCLComm = void*;
-using CCLRootId = std::vector<uint8_t>;
+typedef void* CCLComm;
+typedef std::vector<uint8_t> CCLRootId;
 
 enum CCLReduceOp { SUM = 0, AVG, MAX, MIN, PRODUCT };
 enum CCLDataType {
diff --git a/paddle/phi/backends/custom/custom_context.cc b/paddle/phi/backends/custom/custom_context.cc
index c5b7676df48..ddba0baea7e 100644
--- a/paddle/phi/backends/custom/custom_context.cc
+++ b/paddle/phi/backends/custom/custom_context.cc
@@ -44,9 +44,15 @@ struct CustomContext::Impl {
 
   void Wait() const { stream_->Wait(); }
 
+  phi::ccl::CCLComm xccl_comm() const { return comm_; }
+
+  void set_xccl_comm(phi::ccl::CCLComm comm) { comm_ = comm; }
+
   Place place_;
 
   std::shared_ptr<phi::stream::Stream> stream_;
+
+  phi::ccl::CCLComm comm_;
 };
 
 void CustomContext::Init() { impl_->Init(); }
@@ -72,4 +78,11 @@ CustomContext::CustomContext(const CustomPlace& place)
 
 CustomContext::~CustomContext() { impl_->Init(); }
 
+phi::ccl::CCLComm CustomContext::xccl_comm() const {
+  return impl_->xccl_comm();
+}
+
+void CustomContext::set_xccl_comm(phi::ccl::CCLComm comm) {
+  impl_->set_xccl_comm(comm);
+}
 }  // namespace phi
diff --git a/paddle/phi/backends/custom/custom_context.h b/paddle/phi/backends/custom/custom_context.h
index d12ba20d4b6..68abfeae366 100644
--- a/paddle/phi/backends/custom/custom_context.h
+++ b/paddle/phi/backends/custom/custom_context.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <memory>
 
+#include "paddle/phi/backends/c_comm_lib.h"
 #include "paddle/phi/backends/stream.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/device_context.h"
@@ -63,6 +64,12 @@ class CustomContext : public DeviceContext,
   // all resources and delete them when destructing.
   void Init();
 
+  /*! \brief Return xccl communicators. */
+  phi::ccl::CCLComm xccl_comm() const;
+
+  /*! \brief Set xccl communicators. */
+  void set_xccl_comm(phi::ccl::CCLComm comm);
+
  private:
   CustomContext();
 
-- 
GitLab
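
Illustrative usage note (not part of the patch): the sketch below shows how the XCCLParallelContext introduced above might be driven for a two-trainer dynamic-graph all-reduce on a custom device. The device type "npu", the endpoints, the rank values, and the tensor setup are placeholder assumptions for illustration; only types and methods added or touched by this change are used.

  // Sketch only: rank-0 side of a 2-trainer setup; all concrete values are placeholders.
  paddle::imperative::ParallelStrategy strategy;
  strategy.nranks_ = 2;
  strategy.local_rank_ = 0;
  strategy.trainer_endpoints_ = {"127.0.0.1:6170", "127.0.0.1:6171"};
  strategy.current_endpoint_ = "127.0.0.1:6170";
  strategy.nrings_ = 1;

  // "npu" stands for whatever device type a custom-device plugin registers.
  paddle::platform::CustomPlace place("npu", 0);
  paddle::imperative::XCCLParallelContext ctx(strategy, place);
  ctx.Init();  // exchanges the CCLRootId over sockets and creates the ring-0 comm

  paddle::framework::Variable src, dst;
  auto *src_tensor = src.GetMutable<phi::DenseTensor>();
  // ... resize and fill src_tensor with data residing on `place` ...
  ctx.AllReduceByStream(src, &dst, /*ring_id=*/0, /*use_calc_stream=*/false);
  ctx.WaitComm(0);  // compute stream waits on the comm stream before reading dst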