Unverified commit 8525bc63, authored by lilong12 and committed by GitHub

add send/recv to/from switch module for ProcessGroupHeter (#41285) (#41502)

Parent f2a8d053
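The commit also registers a pybind11 binding for ProcessGroupHeter (see the distributed_py.cc hunks near the end of this diff). As a rough orientation, a hypothetical Python-side construction might look like the sketch below; the `core` import path and the `store` object are assumptions, only the keyword arguments mirror the py::init<> signature added here.

# Hypothetical sketch only: the module path and the `store` object are assumed;
# the keyword arguments mirror the ProcessGroupHeter binding added in this commit.
from paddle.fluid import core  # assumed location of the compiled bindings

store = ...  # a distributed Store instance (e.g. a TCP store), created elsewhere
pg = core.ProcessGroupHeter(
    store, rank=0, world_size=2,
    gid=0, local_rank=0, local_size=1,
    gloo_rank=0, gloo_size=1,
    with_switch=True,                   # route inter-cluster traffic through the switch module
    switch_endpoint="127.0.0.1:6175")   # endpoint of the cloud switch service (example value)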
@@ -244,3 +244,7 @@ if(WITH_ROCM)
   string (REPLACE "-Werror" "-Wno-error" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
 endif()
+if(WITH_PSCORE OR WITH_PSLIB)
+  string (REPLACE "-Wnon-virtual-dtor" "-Wno-non-virtual-dtor" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
+  string (REPLACE "-Wnon-virtual-dtor" "-Wno-non-virtual-dtor" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
+endif()
@@ -7,14 +7,14 @@ endif()
 if(WITH_NCCL)
   cc_library(processgroup_nccl SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api)
-  if (WITH_DISTRIBUTE)
+  if (WITH_DISTRIBUTE AND WITH_PSCORE)
     cc_library(processgroup_heter SRCS ProcessGroupHeter.cc NCCLTools.cc Common.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api)
   endif()
 endif()
 if(WITH_ASCEND_CL)
   cc_library(processgroup_hccl SRCS ProcessGroupHCCL.cc HCCLTools.cc Common.cc DEPS place npu_stream enforce collective_helper device_context phi phi_api eager_api)
-  if (WITH_DISTRIBUTE)
+  if (WITH_DISTRIBUTE AND WITH_PSCORE)
     cc_library(processgroup_heter SRCS ProcessGroupHeter.cc HCCLTools.cc Common.cc DEPS place npu_stream enforce collective_helper device_context phi phi_api eager_api)
   endif()
 endif()
@@ -35,10 +35,10 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
 void ProcessGroup::Task::Synchronize() {}

 ProcessGroup::ProcessGroup(int rank, int size, int gid)
-    : rank_(rank), size_(size) {
+    : rank_(rank), size_(size), gid_(gid) {
   if (gid != IGNORE_ID) {
     auto map = ProcessGroupMapFromGid::getInstance();
-    map->insert(gid, this);
+    map->insert(gid_, this);
   }
 }
......
@@ -93,8 +93,8 @@ class ProcessGroup {
   }

   virtual void Broadcast(const phi::DenseTensor* in, phi::DenseTensor* out) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support broadcast for static",
+    PADDLE_THROW(platform::errors::Fatal(
+        "ProcessGroup%s does not support broadcast for static mode runtime",
         GetBackendName()));
   }
@@ -148,6 +148,7 @@ class ProcessGroup {
  protected:
   const int rank_;
   const int size_;
+  const int gid_;
 };

 class ProcessGroupMapFromGid {
@@ -158,17 +159,20 @@ class ProcessGroupMapFromGid {
   }

   void insert(int gid, ProcessGroup* pg) {
+    // TODO(sandyhouse): address ut and uncomment the following codes
     // PADDLE_ENFORCE_EQ(has(gid), false,
     //                   platform::errors::PreconditionNotMet(
-    //                       "The process group with id %d does exist.", gid));
+    //                       "The process group with id %d doesnot exist.",
+    //                       gid));
     map_[gid] = pg;
   }

   ProcessGroup* get(int gid) {
+    // TODO(sandyhouse): address ut and uncomment the following codes
     // PADDLE_ENFORCE_EQ(has(gid), true,
     //                   platform::errors::PreconditionNotMet(
     //                       "The process group with id %d doesnot exist.",
     //                       gid));
     return map_.find(gid)->second;
   }
......
@@ -30,12 +30,6 @@ constexpr int64_t kWaitBlockTImeout = 10;
 namespace paddle {
 namespace distributed {

-// bool CheckTensorsInNPUPlace(const std::vector<Tensor>& tensors) {
-//   return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) {
-//     return t.place() == platform::DeviceType::NPU;
-//   });
-// }
-
 void SyncDefaultStream(
     const std::vector<Place>& places,
     std::vector<NPUEventManager>& hcclEvents,  // NOLINT
......
@@ -56,7 +56,8 @@ ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
       local_size_(local_size),
       gloo_rank_(gloo_rank),
       gloo_size_(gloo_size),
-      with_switch_(with_switch) {
+      with_switch_(with_switch),
+      switch_endpoint_(switch_endpoint) {
 #if defined(PADDLE_WITH_NCCL)
   inner_pg_ = std::make_shared<ProcessGroupNCCL>(store, local_rank, local_size,
                                                  IGNORE_ID);
@@ -64,14 +65,10 @@ ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
   inner_pg_ = std::make_shared<ProcessGroupHCCL>(store, local_rank, local_size,
                                                  IGNORE_ID);
 #else
-  PADDLE_THROW(platform::errors::InvalidArgument(
+  PADDLE_THROW(platform::errors::Fatal(
       "ProcessGroupHeter only supports NCCL and HCCL now.");
 #endif
-  if (with_switch_) {
-    // TODO(sandyhouse) starts a client to connect the cloud switch module
-    // std::shared_ptr<HeterClient> client_ =
-    //     HeterClient::GetInstance({switch_endpoint}, {}, 0);
-  } else if (local_rank_ == 0) {
+  if (local_rank_ == 0 && !with_switch_) {
     auto opts = ProcessGroupGloo::GlooOptions::create();
     opts->device = ProcessGroupGloo::createDefaultDevice();
     inter_pg_ = std::make_shared<ProcessGroupGloo>(store, gloo_rank_,
@@ -79,6 +76,15 @@ ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
   }
 }

+template <typename T>
+static void _do_add(T* dst, T* src, size_t size) {
+  for (size_t i = 0; i < size; i++) {
+    *dst += *src;
+    dst++;
+    src++;
+  }
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
     std::vector<Tensor>& tensors, const AllreduceOptions& opts) {
 #if defined(PADDLE_WITH_NCCL)
@@ -93,33 +99,92 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::AllReduce(
   // Step2: copy tensors to CPU
   if (local_rank_ == 0) {
-    std::vector<Tensor> cpu_tensors(tensors.size());
+    std::vector<Tensor> cpu_tensors;
+    cpu_tensors.reserve(tensors.size());
     for (size_t i = 0; i < tensors.size(); i++) {
       auto dense_gpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
-      auto dense_cpu_tensor =
-          std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
-      dense_cpu_tensor->Resize(tensors[i].dims());
+      phi::DenseTensorMeta meta = phi::DenseTensorMeta(
+          dense_gpu_tensor->dtype(), dense_gpu_tensor->dims());
+      std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
+          std::make_shared<phi::DenseTensor>(
+              std::make_unique<paddle::experimental::DefaultAllocator>(
+                  paddle::platform::CPUPlace())
+                  .get(),
+              meta);
+      dense_cpu_tensor->ResizeAndAllocate(dense_gpu_tensor->dims());
+      cpu_tensors[i] = paddle::experimental::Tensor(dense_cpu_tensor);
       framework::TensorCopySync(*dense_gpu_tensor, platform::CPUPlace(),
                                 dense_cpu_tensor.get());
     }
     // Step3: do inter cluster allreduce
     if (with_switch_) {
-      // TODO(sandyhouse) send to and recv from switch, and do add
+      if (local_rank_ == 0) {
+        HeterClient* client_ =
+            HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
+        auto dense_cpu_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[0].impl());
+        std::vector<int> send_size;
+        send_size.push_back(dense_cpu_tensor->numel());
+        int ret = client_->Send(
+            gid_, {dense_cpu_tensor->name()}, send_size,
+            dense_cpu_tensor->data(),
+            dense_cpu_tensor->numel() *
+                framework::DataTypeSize(dense_cpu_tensor->dtype()));
+        PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
+                                      "Send to the switch module error."));
+        phi::DenseTensorMeta meta = phi::DenseTensorMeta(
+            dense_cpu_tensor->dtype(), dense_cpu_tensor->dims());
+        std::shared_ptr<phi::DenseTensor> dense_cpu_tensor2 =
+            std::make_shared<phi::DenseTensor>(
+                std::make_unique<paddle::experimental::DefaultAllocator>(
+                    paddle::platform::CPUPlace())
+                    .get(),
+                meta);
+        dense_cpu_tensor2->ResizeAndAllocate(dense_cpu_tensor->dims());
+        Tensor cpu_tensor_temp =
+            paddle::experimental::Tensor(dense_cpu_tensor2);
+        ret = client_->Recv(
+            gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor2->data(),
+            dense_cpu_tensor2->numel() *
+                framework::DataTypeSize(dense_cpu_tensor2->dtype()));
+        PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
+                                      "Recv from the switch module error."));
+        switch (dense_cpu_tensor->dtype()) {
+          case DataType::FLOAT32:
+            _do_add<float>(reinterpret_cast<float*>(dense_cpu_tensor->data()),
+                           reinterpret_cast<float*>(dense_cpu_tensor2->data()),
+                           dense_cpu_tensor->numel());
+            break;
+          case DataType::FLOAT64:
+            _do_add<double>(
+                reinterpret_cast<double*>(dense_cpu_tensor->data()),
+                reinterpret_cast<double*>(dense_cpu_tensor2->data()),
+                dense_cpu_tensor->numel());
+            break;
+          case DataType::INT32:
+            _do_add<int>(reinterpret_cast<int*>(dense_cpu_tensor->data()),
+                         reinterpret_cast<int*>(dense_cpu_tensor2->data()),
+                         dense_cpu_tensor->numel());
+            break;
+          default:
+            PADDLE_THROW(platform::errors::PreconditionNotMet(
+                "Unsupported data type (%s) to do add.",
+                framework::DataType2String(dense_cpu_tensor->dtype())));
+        }
+      }
     } else {
       auto gloo_task = inter_pg_->AllReduce(cpu_tensors, opts);
       gloo_task->Wait();
     }
     // Step4: copy cpu tensors to gpu
-    // TODO(sandyhouse)
     // copy cpu tensors to gpu
     for (size_t i = 0; i < tensors.size(); i++) {
       auto dense_gpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
       auto dense_cpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
-      // framework::TensorCopySync(*dense_cpu_tensor, tensors[i].place(),
-      //                           dense_gpu_tensor.get());
       framework::TensorCopySync(*dense_cpu_tensor, dense_cpu_tensor->place(),
                                 dense_gpu_tensor.get());
     }
@@ -147,18 +212,57 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
   inner_pg_->Broadcast(tensors, b_opts);

   if (local_rank_ == 0) {
-    std::vector<Tensor> cpu_tensors(tensors.size());
+    std::vector<Tensor> cpu_tensors;
+    cpu_tensors.reserve(tensors.size());
     for (size_t i = 0; i < tensors.size(); i++) {
       auto dense_gpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
-      auto dense_cpu_tensor =
-          std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
-      dense_cpu_tensor->Resize(tensors[i].dims());
+      phi::DenseTensorMeta meta = phi::DenseTensorMeta(
+          dense_gpu_tensor->dtype(), dense_gpu_tensor->dims());
+      std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
+          std::make_shared<phi::DenseTensor>(
+              std::make_unique<paddle::experimental::DefaultAllocator>(
+                  paddle::platform::CPUPlace())
+                  .get(),
+              meta);
+      dense_cpu_tensor->ResizeAndAllocate(dense_gpu_tensor->dims());
+      cpu_tensors[i] = paddle::experimental::Tensor(dense_cpu_tensor);
       framework::TensorCopySync(*dense_gpu_tensor, platform::CPUPlace(),
                                 dense_cpu_tensor.get());
     }
     if (with_switch_) {
-      // TODO(sandyhouse) send to and recv
+      if (local_rank_ == 0) {
+        HeterClient* client_ =
+            HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
+        auto dense_cpu_tensor =
+            std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[0].impl());
+        if (gloo_rank_ == 0) {
+          std::vector<int> send_size;
+          send_size.push_back(dense_cpu_tensor->numel());
+          int ret = client_->Send(
+              gid_, {dense_cpu_tensor->name()}, send_size,
+              dense_cpu_tensor->data(),
+              dense_cpu_tensor->numel() *
+                  framework::DataTypeSize(dense_cpu_tensor->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
+                                        "Send to the switch module error."));
+        } else {
+          int ret = client_->Recv(
+              gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor->data(),
+              dense_cpu_tensor->numel() *
+                  framework::DataTypeSize(dense_cpu_tensor->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0,
+                            platform::errors::PreconditionNotMet(
+                                "Receive from the switch module error."));
+          ret = client_->Recv(
+              gid_, {dense_cpu_tensor->name()}, dense_cpu_tensor->data(),
+              dense_cpu_tensor->numel() *
+                  framework::DataTypeSize(dense_cpu_tensor->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0,
+                            platform::errors::PreconditionNotMet(
+                                "Receive from the switch module error."));
+        }
+      }
     } else {
       auto gloo_task = inter_pg_->Broadcast(cpu_tensors, opts);
       gloo_task->Wait();
@@ -168,8 +272,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHeter::Broadcast(
           std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
       auto dense_cpu_tensor =
           std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensors[i].impl());
-      // framework::TensorCopySync(*dense_cpu_tensor, tensors[i].place(),
-      //                           dense_gpu_tensor.get());
       framework::TensorCopySync(*dense_cpu_tensor, dense_cpu_tensor->place(),
                                 dense_gpu_tensor.get());
     }
@@ -185,22 +287,44 @@ void ProcessGroupHeter::Broadcast(const phi::DenseTensor* in,
   inner_pg_->Broadcast(in, out);

   if (local_rank_ == 0) {
-    Tensor cpu_tensor;
-    auto dense_cpu_tensor =
-        std::dynamic_pointer_cast<phi::DenseTensor>(cpu_tensor.impl());
-    dense_cpu_tensor->Resize(in->dims());
+    phi::DenseTensorMeta meta = phi::DenseTensorMeta(in->dtype(), in->dims());
+    std::shared_ptr<phi::DenseTensor> dense_cpu_tensor =
+        std::make_shared<phi::DenseTensor>(
+            std::make_unique<paddle::experimental::DefaultAllocator>(
+                paddle::platform::CPUPlace())
+                .get(),
+            meta);
+    dense_cpu_tensor->ResizeAndAllocate(in->dims());
+    Tensor cpu_tensor = paddle::experimental::Tensor(dense_cpu_tensor);
     framework::TensorCopySync(*in, platform::CPUPlace(),
                               dense_cpu_tensor.get());
     if (with_switch_) {
-      // TODO(sandyhouse) send to and recv
+      if (local_rank_ == 0) {
+        HeterClient* client_ =
+            HeterClient::GetInstance({switch_endpoint_}, {}, 0).get();
+        if (gloo_rank_ == 0) {
+          std::vector<int> send_size;
+          send_size.push_back(in->numel());
+          int ret = client_->Send(
+              gid_, {in->name()}, send_size, dense_cpu_tensor->data(),
+              in->numel() * framework::DataTypeSize(in->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
+                                        "Send to the switch module error."));
+        } else {
+          int ret =
+              client_->Recv(gid_, {in->name()}, dense_cpu_tensor->data(),
+                            in->numel() * framework::DataTypeSize(in->dtype()));
+          PADDLE_ENFORCE_EQ(ret, 0,
+                            platform::errors::PreconditionNotMet(
+                                "Receive from the switch module error."));
+        }
+      }
     } else {
       std::vector<Tensor> cpu_tensors = {cpu_tensor};
-      // auto gloo_task = inter_pg_->Broadcast(cpu_tensors);
-      // gloo_task->Wait();
-      inter_pg_->Broadcast(cpu_tensors);
+      auto gloo_task = inter_pg_->Broadcast(cpu_tensors);
+      gloo_task->Wait();
     }
-    framework::TensorCopySync(*dense_cpu_tensor, dense_cpu_tensor->place(),
-                              out);
+    framework::TensorCopySync(*dense_cpu_tensor, out->place(), out);
   }
   inner_pg_->Broadcast(out, out);
 }
......
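When with_switch_ is set, the AllReduce path above reduces the tensor received back from the switch into the local partial sum with a plain elementwise add (the _do_add<T> helper). A small NumPy sketch of just that reduction step, for illustration only; the array values are made up.

import numpy as np

# Equivalent of _do_add<float>: add the partial sum received from the switch
# (via client_->Recv) into the buffer holding the local cluster's allreduce result.
local_sum = np.array([1.0, 2.0, 3.0], dtype=np.float32)    # intra-cluster allreduce result
from_switch = np.array([4.0, 5.0, 6.0], dtype=np.float32)  # data received from the switch module
local_sum += from_switch                                    # dst[i] += src[i]
print(local_sum)                                            # [5. 7. 9.]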
@@ -23,7 +23,6 @@
 #include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
-// #include "paddle/fluid/distributed/ps/service/heter_client.h"
 #include "paddle/fluid/platform/device_context.h"

 #ifdef PADDLE_WITH_GLOO
@@ -48,6 +47,11 @@
 #include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
 #endif

+#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
+#include "paddle/fluid/distributed/ps/service/heter_client.h"
+#endif
+
 #include "paddle/fluid/distributed/collective/Common.h"

 constexpr const char* HETER_BACKEND_NAME = "HETER_BACKEND";
@@ -108,6 +112,7 @@ class ProcessGroupHeter : public ProcessGroup {
   int gloo_rank_;
   int gloo_size_;
   bool with_switch_;
+  std::string switch_endpoint_;
 };

 }  // namespace distributed
......
@@ -226,6 +226,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   return task;
 }

+template <typename Fn>
+void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
+                                  phi::DenseTensor* out, Fn fn,
+                                  CommType op_type) {
+  std::vector<Place> places;
+  places.push_back(in->place());
+  const auto key = GetKeyFromPlaces(places);
+
+  {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (places_to_ncclcomm_.find(key) == places_to_ncclcomm_.end()) {
+      CreateNCCLManagerCache(key, places);
+    }
+  }
+
+  auto& nccl_comms = places_to_ncclcomm_[key];
+
+  SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
+
+  // construct uninitialize guard for device
+  platform::CUDADeviceGuard cuda_guard;
+
+  if (FLAGS_use_stream_safe_cuda_allocator) {
+    cuda_guard.SetDevice(places[0]);
+    memory::RecordStream(in->Holder(), places_to_ctx_[key][0]->stream());
+  }
+
+  {
+    platform::NCCLGroupGuard nccl_guard;
+    cuda_guard.SetDevice(places[0]);
+    const auto& nccl_stream = places_to_ctx_[key][0]->stream();
+    fn(in, out, nccl_comms[0]->GetNcclComm(), nccl_stream);
+  }
+
+  cuda_guard.SetDevice(places[0]);
+}
+
 template <typename Fn>
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
     std::vector<Tensor>& tensors, Fn fn, int dst_rank, CommType op_type) {
......
@@ -146,6 +146,10 @@ class ProcessGroupNCCL : public ProcessGroup {
       std::vector<Tensor>& outputs,  // NOLINT
       Fn fn, CommType op_type);

+  template <typename Fn>
+  void Collective(const phi::DenseTensor*, phi::DenseTensor*, Fn fn,
+                  CommType op_type);
+
   template <typename Fn>
   std::shared_ptr<ProcessGroup::Task> PointToPoint(
       std::vector<Tensor>& tensors,  // NOLINT
......
@@ -37,7 +37,6 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     int rid = ctx.Attr<int>("ring_id");
     auto place = ctx.GetPlace();
-    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     auto map = distributed::ProcessGroupMapFromGid::getInstance();
     if (map->has(rid)) {
       // Use ProcessGroup
@@ -46,6 +45,7 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
       return;
     }
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
......
@@ -91,12 +91,18 @@ if(NOT ON_INFER)
   set (PYBIND_DEPS ${PYBIND_DEPS} processgroup eager_reducer)
   if (WITH_NCCL)
     set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_nccl)
+    if (WITH_PSCORE)
+      set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_heter)
+    endif()
   endif()
   if (WITH_GLOO)
     set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_gloo)
   endif()
   if(WITH_ASCEND_CL)
     set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_hccl)
+    if (WITH_PSCORE)
+      set (PYBIND_DEPS ${PYBIND_DEPS} processgroup_heter)
+    endif()
   endif()
   set(PYBIND_SRCS ${PYBIND_SRCS} distributed_py.cc)
 endif()
......
@@ -39,6 +39,11 @@ limitations under the License. */
 #include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
 #endif

+#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
+#include "paddle/fluid/distributed/collective/ProcessGroupHeter.h"
+#endif
+
 #if defined(PADDLE_WITH_GLOO)
 #include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
 #include "paddle/fluid/distributed/store/tcp_store.h"
@@ -217,6 +222,21 @@ void BindDistributed(py::module *m) {
                     int>(),
            py::arg("store"), py::arg("rank"), py::arg("world_size"),
            py::arg("group_id") = 0, py::call_guard<py::gil_scoped_release>());
+#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
+  py::class_<distributed::ProcessGroupHeter,
+             std::shared_ptr<distributed::ProcessGroupHeter>>(
+      *m, "ProcessGroupHeter", ProcessGroup)
+      .def(py::init<const std::shared_ptr<distributed::Store> &, int, int, int,
+                    int, int, int, int, bool, std::string>(),
+           py::arg("store"), py::arg("rank"), py::arg("world_size"),
+           py::arg("gid") = 0, py::arg("local_rank") = 0,
+           py::arg("local_size") = 1, py::arg("gloo_rank") = 0,
+           py::arg("gloo_size") = 1, py::arg("with_switch") = false,
+           py::arg("switch_endpoint") = "",
+           py::call_guard<py::gil_scoped_release>());
+#endif
 #endif

 #if defined(PADDLE_WITH_ASCEND_CL)
@@ -227,6 +247,21 @@ void BindDistributed(py::module *m) {
                     int>(),
            py::arg("store"), py::arg("rank"), py::arg("world_size"),
            py::arg("group_id") = 0, py::call_guard<py::gil_scoped_release>());
+#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
+  py::class_<distributed::ProcessGroupHeter,
+             std::shared_ptr<distributed::ProcessGroupHeter>>(
+      *m, "ProcessGroupHeter", ProcessGroup)
+      .def(py::init<const std::shared_ptr<distributed::Store> &, int, int, int,
+                    int, int, int, int, bool, std::string>(),
+           py::arg("store"), py::arg("rank"), py::arg("world_size"),
+           py::arg("gid") = 0, py::arg("local_rank") = 0,
+           py::arg("local_size") = 1, py::arg("gloo_rank") = 0,
+           py::arg("gloo_size") = 1, py::arg("with_switch") = false,
+           py::arg("switch_endpoint") = "",
+           py::call_guard<py::gil_scoped_release>());
+#endif
 #endif

 py::class_<distributed::ProcessGroup::Task,
......