未验证 提交 d03bbefa 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add inference MP support, PART0 (#53719)

* [CustomDevice] add inference MP support, PART0

* update
上级 eb97f4f0
......@@ -123,10 +123,26 @@ if(NOT WIN32)
SRCS reducer.cc
DEPS layer)
endif()
if(WITH_CUSTOM_DEVICE)
cc_library(
xccl_context
SRCS xccl_context.cc
DEPS collective_helper device_context tensor var_type_traits)
if(NOT
(WITH_NCCL
OR WITH_RCCL
OR WITH_XPU_BKCL
OR WITH_GLOO))
cc_library(
reducer
SRCS reducer.cc
DEPS layer)
endif()
endif()
if(WITH_NCCL
OR WITH_RCCL
OR WITH_XPU_BKCL)
OR WITH_XPU_BKCL
OR WITH_CUSTOM_DEVICE)
cc_library(
heter_ccl_context
SRCS heter_ccl_context.cc
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/xccl_context.h"
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#endif
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace imperative {
static void XcclAllReduce(const phi::DenseTensor &src,
phi::DenseTensor *dst,
const phi::stream::Stream &stream,
const phi::ccl::CCLComm &comm) {
const auto &place = src.place();
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
platform::errors::Unimplemented(
"Dynamic graph mode does not support multi-CPU training yet."));
void *src_ptr = const_cast<void *>(src.data());
dst->Resize(src.dims());
auto *dst_ptr = phi::DeviceContextPool::Instance()
.Get(src.place())
->Alloc(dst, src.dtype());
auto xccl_dtype = phi::ccl::ToCCLDataType(src.dtype());
phi::DeviceManager::CCLAllReduce(place.GetDeviceType(),
src_ptr,
dst_ptr,
src.numel(),
xccl_dtype,
phi::ccl::CCLReduceOp::SUM,
comm,
stream);
}
void XCCLParallelContext::BcastXCCLId(
std::vector<phi::ccl::CCLRootId> &xccl_ids, // NOLINT
int root,
int server_fd) {
if (strategy_.local_rank_ == root) {
std::vector<std::string> other_trainers;
for (auto &ep : strategy_.trainer_endpoints_) {
if (ep != strategy_.current_endpoint_) {
other_trainers.push_back(ep);
}
}
platform::SendBroadCastCommID(other_trainers, &xccl_ids);
} else {
platform::RecvBroadCastCommID(
server_fd, strategy_.current_endpoint_, &xccl_ids);
}
}
void XCCLParallelContext::Init() {
int server_fd = -1;
std::vector<phi::ccl::CCLRootId> xccl_ids;
xccl_ids.resize(strategy_.nrings_);
if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
for (size_t i = 0; i < xccl_ids.size(); ++i) {
phi::DeviceManager::CCLGetUniqueId(place_.GetDeviceType(), &xccl_ids[i]);
}
} else {
// FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
// on rank0.
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastXCCLId(xccl_ids, 0, server_fd);
int dev_id = place_.device;
for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " dev id: " << dev_id
<< " ring id: " << ring_id;
// it will assign nccl_comm in phi::CustomContext within ring_id
platform::XCCLCommContext::Instance(place_.GetDeviceType())
.CreateComm(&xccl_ids[ring_id],
strategy_.nranks_,
strategy_.local_rank_,
dev_id,
ring_id);
auto compute_event = std::make_shared<phi::event::Event>();
auto comm_event = std::make_shared<phi::event::Event>();
compute_event->Init(place_);
comm_event->Init(place_);
compute_events_.emplace_back(compute_event);
comm_events_.emplace_back(comm_event);
}
}
void XCCLParallelContext::InitWithRingID(int ring_id) {
int server_fd = -1;
std::vector<phi::ccl::CCLRootId> xccl_ids;
xccl_ids.resize(1);
if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
phi::DeviceManager::CCLGetUniqueId(place_.GetDeviceType(), &xccl_ids[0]);
} else {
// FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
// on rank0.
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
}
BcastXCCLId(xccl_ids, 0, server_fd);
int dev_id = place_.device;
VLOG(0) << "init xccl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " dev id: " << dev_id
<< " ring id: " << ring_id;
// it will assign xccl_comm in phi::CustomContext within ring_id
platform::XCCLCommContext::Instance(place_.GetDeviceType())
.CreateComm(&xccl_ids[0],
strategy_.nranks_,
strategy_.local_rank_,
dev_id,
ring_id);
auto compute_event = std::make_shared<phi::event::Event>();
auto comm_event = std::make_shared<phi::event::Event>();
compute_event->Init(place_);
comm_event->Init(place_);
compute_events_.emplace_back(compute_event);
comm_events_.emplace_back(comm_event);
}
void XCCLParallelContext::AllReduceByStream(const framework::Variable &src,
framework::Variable *dst,
int ring_id,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place_),
true,
platform::errors::Unimplemented(
"Dynamic graph mode does not support multi-CPU training yet."));
auto place = place_;
auto *dev_ctx = static_cast<platform::CustomDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
platform::XCCLComm *comm =
platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(ring_id, place);
auto stream = use_calc_stream ? dev_ctx->GetStream() : comm->stream();
if (src.IsType<phi::DenseTensor>()) {
if (!dst->IsType<phi::DenseTensor>()) {
dst->Clear();
}
XcclAllReduce(src.Get<phi::DenseTensor>(),
dst->GetMutable<phi::DenseTensor>(),
*stream,
comm->comm());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"custom device unsupported variable type %s for imperative allreduce, "
"only "
"LoDTensor are supported.",
platform::demangle(framework::ToTypeName(src.Type()))));
}
}
void XCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
VLOG(3) << "/// DEBUG /// start inter broadcast with ring_id: " << ring_id;
phi::DenseTensor *src_tensor = src->GetMutable<phi::DenseTensor>();
const auto &place = src_tensor->place();
platform::XCCLComm *comm =
platform::XCCLCommContext::Instance(place_.GetDeviceType())
.Get(ring_id, place);
auto stream = comm->stream();
void *src_ptr = src_tensor->data();
auto xccl_dtype = phi::ccl::ToCCLDataType(src_tensor->dtype());
phi::DeviceManager::CCLBroadcast(place_.GetDeviceType(),
src_ptr,
src_tensor->numel(),
xccl_dtype,
0,
comm->comm(),
*stream);
}
paddle::platform::DeviceContext *XCCLParallelContext::GetDeviceContext(
int ring_id) {
return static_cast<platform::DeviceContext *>(
platform::XCCLCommContext::Instance(place_.GetDeviceType())
.Get(ring_id, place_)
->dev_context());
}
void XCCLParallelContext::WaitCompute(int ring_id) {
PADDLE_ENFORCE_GE(
ring_id,
0,
platform::errors::OutOfRange("ring id must >= 0, but got %d", ring_id));
PADDLE_ENFORCE_LT(ring_id,
compute_events_.size(),
platform::errors::OutOfRange(
"ring id must < compute events size,"
"but got ring id = %d, compute events size = %d",
ring_id,
compute_events_.size()));
auto compute_stream = static_cast<phi::CustomContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->GetStream();
auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
.Get(ring_id, place_)
->stream();
auto event = compute_events_[ring_id].get();
// compute_stream-->event-->comm_stream
event->Record(compute_stream.get());
comm_stream->WaitEvent(event);
}
void XCCLParallelContext::WaitComm(int ring_id) {
PADDLE_ENFORCE_GE(
ring_id,
0,
platform::errors::OutOfRange("ring id must >= 0, but got %d", ring_id));
PADDLE_ENFORCE_LT(ring_id,
comm_events_.size(),
platform::errors::OutOfRange(
"ring id must < comm events size,"
"but got ring id = %d, comm events size = %d",
ring_id,
comm_events_.size()));
auto compute_stream = static_cast<phi::CustomContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->GetStream();
auto comm_stream = platform::XCCLCommContext::Instance(place_.GetDeviceType())
.Get(ring_id, place_)
->stream();
auto event = comm_events_[ring_id].get();
// comm_stream-->event-->compute_stream
event->Record(comm_stream.get());
compute_stream->WaitEvent(event);
}
void XCCLParallelContext::SynchronizeCompute() {
auto *compute_dev_ctx = static_cast<phi::CustomContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
}
} // namespace imperative
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/imperative/parallel_context.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace imperative {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
class XCCLParallelContext : public ParallelContext {
public:
explicit XCCLParallelContext(const ParallelStrategy& strategy,
const platform::Place& place)
: ParallelContext(strategy, place) {}
~XCCLParallelContext() override = default;
void BcastXCCLId(std::vector<phi::ccl::CCLRootId>& xccl_ids, // NOLINT
int root,
int server_fd);
void Init() override;
void InitWithRingID(int ring_id) override;
void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst,
int ring_id,
bool use_calc_stream) override;
void Broadcast(framework::Variable* src, int ring_id) override;
paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override;
void WaitCompute(int ring_id) override;
void WaitComm(int ring_id) override;
void SynchronizeCompute() override;
private:
// used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
std::vector<std::shared_ptr<phi::event::Event>> compute_events_;
// used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
std::vector<std::shared_ptr<phi::event::Event>> comm_events_;
};
#endif
} // namespace imperative
} // namespace paddle
......@@ -404,5 +404,235 @@ void BKCLCommContext::ReleaseBKCLComms() {
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
class XCCLCommImpl : public XCCLComm {
public:
void set_ring_id(int ring_id) { ring_id_ = ring_id; }
int ring_id() const override { return ring_id_; }
void set_nranks(int nranks) { nranks_ = nranks; }
int nranks() const override { return nranks_; }
void set_rank(int rank) { rank_ = rank; }
int rank() const override { return rank_; }
int device_id() const override { return dev_ctx_->GetPlace().device; }
void set_comm(phi::ccl::CCLComm comm) { comm_ = comm; }
phi::ccl::CCLComm comm() const override { return comm_; }
std::shared_ptr<phi::stream::Stream> stream() const override {
return dev_ctx_->GetStream();
}
void set_dev_ctx(std::unique_ptr<phi::CustomContext>&& dev_ctx) {
dev_ctx_ = std::move(dev_ctx);
}
phi::CustomContext* dev_context() const override { return dev_ctx_.get(); }
std::shared_ptr<phi::event::Event> compute_event() const override {
return compute_event_;
}
std::shared_ptr<phi::event::Event> comm_event() const override {
return comm_event_;
}
void set_compute_event(std::shared_ptr<phi::event::Event>&& compute_event) {
compute_event_ = std::move(compute_event);
}
void set_comm_event(std::shared_ptr<phi::event::Event>&& comm_event) {
comm_event_ = std::move(comm_event);
}
private:
int ring_id_;
int nranks_;
int rank_;
phi::ccl::CCLComm comm_;
std::unique_ptr<phi::CustomContext> dev_ctx_;
// used for comm wait compute, compute_stream-->event-->comm_stream
std::shared_ptr<phi::event::Event> compute_event_;
// used for compute wait comm, comm_stream-->event-->compute_stream
std::shared_ptr<phi::event::Event> comm_event_;
};
static std::unordered_map<std::string, std::unique_ptr<XCCLCommContext>>
g_xccl_comm_ctx_map;
void XCCLCommContext::Release() {
for (auto& it : g_xccl_comm_ctx_map) {
it.second->ReleaseXCCLComms();
}
g_xccl_comm_ctx_map.clear();
}
XCCLCommContext& XCCLCommContext::Instance(const std::string& device_type) {
if (g_xccl_comm_ctx_map.find(device_type) == g_xccl_comm_ctx_map.end()) {
g_xccl_comm_ctx_map.insert(
{device_type,
std::unique_ptr<XCCLCommContext>(new XCCLCommContext(device_type))});
}
return *g_xccl_comm_ctx_map[device_type];
}
XCCLComm* XCCLCommContext::CreateComm(phi::ccl::CCLRootId* xccl_id,
int nranks,
int rank,
int dev_id,
int ring_id) {
PADDLE_ENFORCE_NOT_NULL(xccl_id,
platform::errors::InvalidArgument(
"The xccl unique id should not be null."));
PADDLE_ENFORCE_GT(
nranks,
1,
platform::errors::InvalidArgument(
"Expected nranks > 1. But received nranks is %d.", nranks));
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::InvalidArgument(
"Expected rank >= 0. But received rank is %d.", rank));
PADDLE_ENFORCE_LT(
rank,
nranks,
platform::errors::InvalidArgument(
"Expected rank < nranks. But received rank is %d, nranks is %d.",
rank,
nranks));
PADDLE_ENFORCE_GE(
dev_id,
0,
platform::errors::InvalidArgument(
"Expected dev_id >= 0. But received dev_id is %d.", dev_id));
phi::ccl::CCLComm comm = nullptr;
phi::DeviceManager::SetDevice(device_type_, dev_id);
phi::DeviceManager::CCLCommInitRank(
device_type_, nranks, xccl_id, rank, &comm);
auto* comm_wrapper = AssignXCCLComm(comm, nranks, rank, dev_id, ring_id);
VLOG(1) << "xccl communicator of rank " << rank << " in ring " << ring_id
<< " has been created on device " << dev_id;
return comm_wrapper;
}
void XCCLCommContext::CreateXCCLCommMultiTrainer(
const std::vector<int>& dev_ids,
phi::ccl::CCLRootId* xccl_id,
int ntrainers,
int train_id,
int ring_id) {
PADDLE_ENFORCE_GT(
dev_ids.size(),
0,
paddle::platform::errors::InvalidArgument(
"dev ids = [%d], it should greater than 0.", dev_ids.size()));
const int kDevices = dev_ids.size();
VLOG(1) << "Begin CreateXCCLCommMultiTrainer. device number: " << kDevices
<< ", ntrainers: " << ntrainers << ", train_id: " << train_id
<< ", rind_id: " << ring_id;
phi::ccl::CCLComm comms[kDevices];
{
for (int i = 0; i < kDevices; i++) {
phi::DeviceManager::SetDevice(device_type_, i);
phi::DeviceManager::CCLCommInitRank(device_type_,
kDevices * ntrainers,
xccl_id,
train_id * kDevices + i,
comms + i);
VLOG(1) << "CCLCommInitRank: " << i;
}
}
PADDLE_ENFORCE_EQ(comm_map_.count(ring_id),
0,
platform::errors::InvalidArgument(
"comm_map_ of ring_id: %s should be 0. %s is provided",
ring_id,
comm_map_.count(ring_id)));
for (int i = 0; i < kDevices; ++i) {
AssignXCCLComm(comms[i],
kDevices * ntrainers,
train_id * kDevices + i,
dev_ids[i],
ring_id);
VLOG(1) << "xccl communicator of train_id " << train_id * kDevices + i
<< " in ring " << ring_id << " has been created on device "
<< dev_ids[i];
}
}
XCCLComm* XCCLCommContext::AssignXCCLComm(
phi::ccl::CCLComm comm, int nranks, int rank, int dev_id, int ring_id) {
auto place = CustomPlace(device_type_, dev_id);
std::unique_ptr<phi::CustomContext> dev_ctx(new phi::CustomContext(place));
dev_ctx->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place)
.get());
dev_ctx->SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx->SetZeroAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(place)
.get());
dev_ctx->SetHostZeroAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
// dev_ctx->PartialInitWithAllocator();
auto compute_event = std::make_shared<phi::event::Event>();
auto comm_event = std::make_shared<phi::event::Event>();
compute_event->Init(place);
comm_event->Init(place);
auto* c = new XCCLCommImpl;
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_comm(comm);
c->set_dev_ctx(std::move(dev_ctx));
c->set_compute_event(std::move(compute_event));
c->set_comm_event(std::move(comm_event));
comm_map_mutex_.lock();
if (comm_map_.count(ring_id) == 0) {
comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<XCCLComm>>());
}
auto& dev2comm = comm_map_[ring_id];
dev2comm.emplace(dev_id, std::unique_ptr<XCCLComm>(c));
comm_map_mutex_.unlock();
if (ring_id == 0) {
auto* dev_ctx = static_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(place));
dev_ctx->set_xccl_comm(comm);
}
VLOG(4) << "add xccl comm: " << comm_map_[ring_id][dev_id].get()
<< ", ring_id:" << ring_id << ", dev_id:" << dev_id;
return comm_map_[ring_id][dev_id].get();
}
void XCCLCommContext::ReleaseXCCLComms() {
for (auto& p : comm_map_) {
for (auto& q : p.second) {
q.second.reset();
}
}
}
#endif
} // namespace platform
} // namespace paddle
......@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/utils/variant.h"
namespace paddle {
......@@ -243,5 +244,112 @@ class BKCLCommContext {
DISABLE_COPY_AND_ASSIGN(BKCLCommContext);
};
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
class XCCLComm {
public:
virtual int ring_id() const = 0;
virtual int nranks() const = 0;
virtual int rank() const = 0;
virtual int device_id() const = 0;
virtual phi::ccl::CCLComm comm() const = 0;
virtual std::shared_ptr<phi::stream::Stream> stream() const = 0;
virtual std::shared_ptr<phi::event::Event> compute_event() const = 0;
virtual std::shared_ptr<phi::event::Event> comm_event() const = 0;
virtual phi::CustomContext* dev_context() const = 0;
virtual ~XCCLComm() = default;
};
// A singleton XCCL communicator context reserves communication ring ids
class XCCLCommContext {
public:
static XCCLCommContext& Instance(const std::string& device_type);
static void Release();
XCCLComm* CreateComm(phi::ccl::CCLRootId* nccl_id,
int nranks,
int rank,
int dev_id,
int ring_id = 0);
void CreateAllXCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
void CreateXCCLCommMultiTrainer(const std::vector<int>& dev_ids,
phi::ccl::CCLRootId* xccl_id,
int nranks,
int rank,
int ring_id);
// a latter comm with the same dev_id and the same ring_id
// will override the former
XCCLComm* AssignXCCLComm(phi::ccl::CCLComm comm,
int nranks,
int rank,
int dev_id,
int ring_id = 0);
// retrieve a communicator by the ring id in multiprocessing mode
XCCLComm* Get(int ring_id) const {
PADDLE_ENFORCE_GT(
comm_map_.count(ring_id),
0,
platform::errors::InvalidArgument(
"Communicator in ring id %d has not been initialized.", ring_id));
PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(),
1,
platform::errors::InvalidArgument(
"One device id should be specified to retrieve from "
"multiple communicators."));
return comm_map_.at(ring_id).begin()->second.get();
}
int GetRingId(phi::ccl::CCLComm comm) const {
for (const auto& pair : comm_map_) {
for (const auto& p : pair.second) {
if (p.second.get()->comm() == comm) {
return pair.first;
}
}
}
return -1;
}
// retrieve a communicator by the ring id and the device id
XCCLComm* Get(int ring_id, int dev_id) const {
PADDLE_ENFORCE_GT(
comm_map_.count(ring_id),
0,
platform::errors::InvalidArgument(
"Communicator of ring id %d has not been initialized.", ring_id));
PADDLE_ENFORCE_GT(
comm_map_.at(ring_id).count(dev_id),
0,
platform::errors::InvalidArgument(
"Communicator at device id %d has not been initialized in ring %d.",
dev_id,
ring_id));
return comm_map_.at(ring_id).at(dev_id).get();
}
// retrieve a communicator by the ring id and place
XCCLComm* Get(int ring_id, Place place) const {
return Get(ring_id, place.device);
}
private:
std::string device_type_;
std::once_flag once_flag_;
std::mutex comm_map_mutex_;
// ring id to dev-XCCLComm
std::map<int, std::map<int, std::unique_ptr<XCCLComm>>> comm_map_;
void ReleaseXCCLComms();
XCCLCommContext() = default;
explicit XCCLCommContext(const std::string& device_type)
: device_type_(device_type) {}
DISABLE_COPY_AND_ASSIGN(XCCLCommContext);
};
#endif
} // namespace platform
} // namespace paddle
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include <arpa/inet.h>
......@@ -33,6 +33,9 @@ limitations under the License. */
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/phi/backends/c_comm_lib.h"
#endif
PHI_DECLARE_int32(get_host_by_name_time);
......@@ -348,6 +351,58 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
"send comm unique id");
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <>
void RecvCommID<phi::ccl::CCLRootId>(int conn, phi::ccl::CCLRootId* nccl_id) {
char buffer[MAX_COMMUNIQUEID_LEN] = {0};
CHECK_SYS_CALL(SocketRecv(conn, buffer, sizeof(size_t)),
"recv comm unique id size");
size_t unique_id_size = *reinterpret_cast<size_t*>(buffer);
VLOG(6) << "RecvCommID size: " << unique_id_size;
nccl_id->resize(unique_id_size);
size_t n_repeat = unique_id_size / MAX_COMMUNIQUEID_LEN;
size_t n_remain = unique_id_size % MAX_COMMUNIQUEID_LEN;
for (size_t i = 0; i < n_repeat; ++i) {
CHECK_SYS_CALL(SocketRecv(conn, buffer, MAX_COMMUNIQUEID_LEN),
"recv comm unique id");
memcpy(nccl_id->data() + i * MAX_COMMUNIQUEID_LEN,
buffer,
MAX_COMMUNIQUEID_LEN);
}
if (n_remain) {
CHECK_SYS_CALL(SocketRecv(conn, buffer, n_remain), "recv comm unique id");
memcpy(nccl_id->data() + n_repeat * MAX_COMMUNIQUEID_LEN, buffer, n_remain);
}
VLOG(6) << "RecvCommID done";
}
template <>
void SendCommID<phi::ccl::CCLRootId>(int conn, phi::ccl::CCLRootId* nccl_id) {
char buffer[MAX_COMMUNIQUEID_LEN] = {0};
size_t unique_id_size = nccl_id->size();
VLOG(6) << "SendCommID size: " << unique_id_size;
memcpy(buffer, &unique_id_size, sizeof(size_t));
CHECK_SYS_CALL(SocketSend(conn, buffer, sizeof(size_t)),
"send comm unique id size");
size_t n_repeat = unique_id_size / MAX_COMMUNIQUEID_LEN;
size_t n_remain = unique_id_size % MAX_COMMUNIQUEID_LEN;
for (size_t i = 0; i < n_repeat; ++i) {
memcpy(buffer,
nccl_id->data() + i * MAX_COMMUNIQUEID_LEN,
MAX_COMMUNIQUEID_LEN);
CHECK_SYS_CALL(SocketSend(conn, buffer, MAX_COMMUNIQUEID_LEN),
"send comm unique id");
}
if (n_remain) {
memcpy(buffer, nccl_id->data() + n_repeat * MAX_COMMUNIQUEID_LEN, n_remain);
CHECK_SYS_CALL(SocketSend(conn, buffer, n_remain), "send comm unique id");
}
VLOG(6) << "SendCommID done";
}
#endif
template <typename CommUniqueId>
void SendBroadCastCommID(std::vector<std::string> servers,
std::vector<CommUniqueId>* nccl_ids,
......@@ -444,6 +499,9 @@ INSTANT_TEMPLATE(ncclUniqueId)
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE(BKCLUniqueId)
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
INSTANT_TEMPLATE(phi::ccl::CCLRootId)
#endif
} // namespace platform
} // namespace paddle
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
#include <functional>
#include <memory>
#include <mutex>
......
......@@ -23,8 +23,8 @@
namespace phi {
namespace ccl {
using CCLComm = void*;
using CCLRootId = std::vector<uint8_t>;
typedef void* CCLComm;
typedef std::vector<uint8_t> CCLRootId;
enum CCLReduceOp { SUM = 0, AVG, MAX, MIN, PRODUCT };
enum CCLDataType {
......
......@@ -44,9 +44,15 @@ struct CustomContext::Impl {
void Wait() const { stream_->Wait(); }
phi::ccl::CCLComm xccl_comm() const { return comm_; }
void set_xccl_comm(phi::ccl::CCLComm comm) { comm_ = comm; }
Place place_;
std::shared_ptr<phi::stream::Stream> stream_;
phi::ccl::CCLComm comm_;
};
void CustomContext::Init() { impl_->Init(); }
......@@ -72,4 +78,11 @@ CustomContext::CustomContext(const CustomPlace& place)
CustomContext::~CustomContext() { impl_->Init(); }
phi::ccl::CCLComm CustomContext::xccl_comm() const {
return impl_->xccl_comm();
}
void CustomContext::set_xccl_comm(phi::ccl::CCLComm comm) {
impl_->set_xccl_comm(comm);
}
} // namespace phi
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
......@@ -63,6 +64,12 @@ class CustomContext : public DeviceContext,
// all resources and delete them when destructing.
void Init();
/*! \brief Return xccl communicators. */
phi::ccl::CCLComm xccl_comm() const;
/*! \brief Set nccl communicators. */
void set_xccl_comm(phi::ccl::CCLComm comm);
private:
CustomContext();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册