未验证 提交 dc82fa96 作者: LiYuRio 提交者: GitHub

Use string as the unique_key of comm_context_manager (#55726)

* use string as key for comm_context_manager

* remove device_id from comm_context
上级 82ebe9b9
...@@ -694,7 +694,7 @@ std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo( ...@@ -694,7 +694,7 @@ std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo(
opts->device = ProcessGroupGloo::createDefaultDevice(); opts->device = ProcessGroupGloo::createDefaultDevice();
} }
phi::distributed::CommContextManager::CreateGlooCommContext( phi::distributed::CommContextManager::CreateGlooCommContext(
store, gid, rank, size); store, std::to_string(gid), rank, size);
auto process_group = auto process_group =
std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts); std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group); ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
...@@ -705,7 +705,7 @@ phi::distributed::GlooCommContext* ProcessGroupGloo::GetCommContext() { ...@@ -705,7 +705,7 @@ phi::distributed::GlooCommContext* ProcessGroupGloo::GetCommContext() {
const auto& comm_context_manager = const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance(); phi::distributed::CommContextManager::GetInstance();
auto comm_context = static_cast<phi::distributed::GlooCommContext*>( auto comm_context = static_cast<phi::distributed::GlooCommContext*>(
comm_context_manager.Get(this->gid_)); comm_context_manager.Get(std::to_string(this->gid_)));
PADDLE_ENFORCE_NE(comm_context, PADDLE_ENFORCE_NE(comm_context,
nullptr, nullptr,
phi::errors::Unavailable("GlooCommContext is nullptr")); phi::errors::Unavailable("GlooCommContext is nullptr"));
......
...@@ -497,6 +497,9 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, ...@@ -497,6 +497,9 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_ VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << place_key; << ", place: " << place_key;
phi::distributed::CommContextManager::CreateNCCLCommContext(
store_, std::to_string(gid_), rank_, size_);
auto* calc_ctx = static_cast<phi::GPUContext*>( auto* calc_ctx = static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
auto comm_ctx = std::make_unique<phi::GPUContext>(place); auto comm_ctx = std::make_unique<phi::GPUContext>(place);
...@@ -980,12 +983,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter( ...@@ -980,12 +983,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL( std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
const std::shared_ptr<phi::distributed::Store>& store, const std::shared_ptr<phi::distributed::Store>& store,
int device_id,
int rank, int rank,
int size, int size,
int gid) { int gid) {
phi::distributed::CommContextManager::CreateNCCLCommContext(
store, device_id, gid, rank, size);
auto process_group = auto process_group =
std::make_shared<ProcessGroupNCCL>(store, rank, size, gid); std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group); ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
...@@ -996,7 +996,7 @@ phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext() { ...@@ -996,7 +996,7 @@ phi::distributed::NCCLCommContext* ProcessGroupNCCL::GetCommContext() {
const auto& comm_context_manager = const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance(); phi::distributed::CommContextManager::GetInstance();
auto comm_context = static_cast<phi::distributed::NCCLCommContext*>( auto comm_context = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(this->gid_)); comm_context_manager.Get(std::to_string(this->gid_)));
PADDLE_ENFORCE_NE(comm_context, PADDLE_ENFORCE_NE(comm_context,
nullptr, nullptr,
phi::errors::Unavailable("NCCLCommContext is nullptr")); phi::errors::Unavailable("NCCLCommContext is nullptr"));
......
...@@ -69,7 +69,6 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { ...@@ -69,7 +69,6 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
public: public:
static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL( static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
const std::shared_ptr<phi::distributed::Store>& store, const std::shared_ptr<phi::distributed::Store>& store,
int device_id,
int rank, int rank,
int size, int size,
int gid); int gid);
......
...@@ -1136,8 +1136,8 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base, ...@@ -1136,8 +1136,8 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base,
int ring_id = operator_base->Attr<int>("ring_id"); int ring_id = operator_base->Attr<int>("ring_id");
const auto& comm_context_manager = const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance(); phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(ring_id)) { if (comm_context_manager.Has(std::to_string(ring_id))) {
auto comm_context = comm_context_manager.Get(ring_id); auto comm_context = comm_context_manager.Get(std::to_string(ring_id));
if (!dev_ctx->GetCommContext()) { if (!dev_ctx->GetCommContext()) {
dev_ctx->SetCommContext(comm_context); dev_ctx->SetCommContext(comm_context);
} }
...@@ -1156,8 +1156,8 @@ void SetDeviceCommContext(::ir::Operation* op, ...@@ -1156,8 +1156,8 @@ void SetDeviceCommContext(::ir::Operation* op,
op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data(); op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data();
const auto& comm_context_manager = const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance(); phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(ring_id)) { if (comm_context_manager.Has(std::to_string(ring_id))) {
auto comm_context = comm_context_manager.Get(ring_id); auto comm_context = comm_context_manager.Get(std::to_string(ring_id));
if (!dev_ctx->GetCommContext()) { if (!dev_ctx->GetCommContext()) {
dev_ctx->SetCommContext(comm_context); dev_ctx->SetCommContext(comm_context);
} }
......
...@@ -42,9 +42,9 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> { ...@@ -42,9 +42,9 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = ctx.cuda_device_context().stream(); gpuStream_t stream = ctx.cuda_device_context().stream();
const auto& comm_context_manager = const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance(); phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(rid)) { if (comm_context_manager.Has(std::to_string(rid))) {
auto* comm_context = static_cast<phi::distributed::NCCLCommContext*>( auto* comm_context = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(rid)); comm_context_manager.Get(std::to_string(rid)));
comm_context->Broadcast(out, *x, root, stream); comm_context->Broadcast(out, *x, root, stream);
} else { } else {
......
...@@ -47,9 +47,9 @@ class CBroadcastOpCPUKernel : public framework::OpKernel<T> { ...@@ -47,9 +47,9 @@ class CBroadcastOpCPUKernel : public framework::OpKernel<T> {
const auto& comm_context_manager = const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance(); phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(rid)) { if (comm_context_manager.Has(std::to_string(rid))) {
auto* comm_context = static_cast<phi::distributed::GlooCommContext*>( auto* comm_context = static_cast<phi::distributed::GlooCommContext*>(
comm_context_manager.Get(rid)); comm_context_manager.Get(std::to_string(rid)));
comm_context->Broadcast(out, *in, root); comm_context->Broadcast(out, *in, root);
} else { } else {
// NOTE: This will be removed after moving this operator to phi. // NOTE: This will be removed after moving this operator to phi.
......
...@@ -46,6 +46,9 @@ void BindCommContextManager(py::module *m) { ...@@ -46,6 +46,9 @@ void BindCommContextManager(py::module *m) {
"create_nccl_comm_context", "create_nccl_comm_context",
&phi::distributed::CommContextManager::CreateNCCLCommContext, &phi::distributed::CommContextManager::CreateNCCLCommContext,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def_static("set_cuda_device_id",
&phi::distributed::CommContextManager::SetCUDADeviceId,
py::call_guard<py::gil_scoped_release>())
#endif #endif
#if defined(PADDLE_WITH_GLOO) #if defined(PADDLE_WITH_GLOO)
.def_static( .def_static(
......
...@@ -1238,7 +1238,6 @@ void BindDistributed(py::module *m) { ...@@ -1238,7 +1238,6 @@ void BindDistributed(py::module *m) {
.def_static("create", .def_static("create",
distributed::ProcessGroupNCCL::CreateProcessGroupNCCL, distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
py::arg("store"), py::arg("store"),
py::arg("device_id"),
py::arg("rank"), py::arg("rank"),
py::arg("world_size"), py::arg("world_size"),
py::arg("group_id") = 0, py::arg("group_id") = 0,
......
...@@ -37,19 +37,21 @@ namespace phi { ...@@ -37,19 +37,21 @@ namespace phi {
namespace distributed { namespace distributed {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void CommContextManager::SetCUDADeviceId(int dev_id) {
phi::backends::gpu::SetDeviceId(dev_id);
}
void CommContextManager::CreateNCCLCommContext( void CommContextManager::CreateNCCLCommContext(
const std::shared_ptr<Store>& store, const std::shared_ptr<Store>& store,
int dev_id, const std::string& unique_comm_key,
int ring_id,
int rank, int rank,
int size) { int size) {
phi::backends::gpu::SetDeviceId(dev_id);
ncclUniqueId nccl_id; ncclUniqueId nccl_id;
if (rank == 0) { if (rank == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id));
} }
std::string unique_key = "NCCLCommContext/" + std::to_string(ring_id); std::string unique_key = "NCCLCommContext/" + unique_comm_key;
if (rank == 0) { if (rank == 0) {
std::vector<uint8_t> nccl_id_wrapper( std::vector<uint8_t> nccl_id_wrapper(
reinterpret_cast<uint8_t*>(&nccl_id), reinterpret_cast<uint8_t*>(&nccl_id),
...@@ -64,16 +66,19 @@ void CommContextManager::CreateNCCLCommContext( ...@@ -64,16 +66,19 @@ void CommContextManager::CreateNCCLCommContext(
std::make_unique<NCCLCommContext>(rank, size, nccl_id); std::make_unique<NCCLCommContext>(rank, size, nccl_id);
auto& comm_context_manager = CommContextManager::GetInstance(); auto& comm_context_manager = CommContextManager::GetInstance();
comm_context_manager.SetStore(store); comm_context_manager.SetStore(store);
comm_context_manager.Emplace(ring_id, std::move(nccl_comm_context)); comm_context_manager.Emplace(unique_comm_key, std::move(nccl_comm_context));
} }
#endif #endif
#if defined(PADDLE_WITH_GLOO) #if defined(PADDLE_WITH_GLOO)
void CommContextManager::CreateGlooCommContext( void CommContextManager::CreateGlooCommContext(
const std::shared_ptr<Store>& store, int ring_id, int rank, int size) { const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
int rank,
int size) {
GlooStore store_wrapper(store); GlooStore store_wrapper(store);
auto gloo_store = std::make_shared<gloo::rendezvous::PrefixStore>( auto gloo_store = std::make_shared<gloo::rendezvous::PrefixStore>(
std::to_string(ring_id), store_wrapper); unique_comm_key, store_wrapper);
auto gloo_device = CreateGlooDevice(); auto gloo_device = CreateGlooDevice();
...@@ -82,31 +87,33 @@ void CommContextManager::CreateGlooCommContext( ...@@ -82,31 +87,33 @@ void CommContextManager::CreateGlooCommContext(
auto& comm_context_manager = CommContextManager::GetInstance(); auto& comm_context_manager = CommContextManager::GetInstance();
// set actual store to manager // set actual store to manager
comm_context_manager.SetStore(store); comm_context_manager.SetStore(store);
comm_context_manager.Emplace(ring_id, std::move(gloo_comm_context)); comm_context_manager.Emplace(unique_comm_key, std::move(gloo_comm_context));
} }
#endif #endif
CommContext* CommContextManager::Emplace( CommContext* CommContextManager::Emplace(
int ring_id, std::unique_ptr<CommContext> comm_context) { const std::string& unique_comm_key,
std::unique_ptr<CommContext> comm_context) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
id_to_comm_context_.find(ring_id), id_to_comm_context_.find(unique_comm_key),
id_to_comm_context_.end(), id_to_comm_context_.end(),
errors::AlreadyExists("Ring id %d already exists in the map.", ring_id)); errors::AlreadyExists("The unique key %s already exists in the map.",
id_to_comm_context_.emplace(ring_id, std::move(comm_context)); unique_comm_key));
return id_to_comm_context_.at(ring_id).get(); id_to_comm_context_.emplace(unique_comm_key, std::move(comm_context));
return id_to_comm_context_.at(unique_comm_key).get();
} }
CommContext* CommContextManager::Get(int ring_id) const { CommContext* CommContextManager::Get(const std::string& unique_comm_key) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
id_to_comm_context_.find(ring_id), id_to_comm_context_.find(unique_comm_key),
id_to_comm_context_.end(), id_to_comm_context_.end(),
errors::NotFound("Can not find ring id %d in map.", ring_id)); errors::NotFound("Can not find unique key %s in map.", unique_comm_key));
return id_to_comm_context_.at(ring_id).get(); return id_to_comm_context_.at(unique_comm_key).get();
} }
bool CommContextManager::Has(int ring_id) const { bool CommContextManager::Has(const std::string& unique_comm_key) const {
return id_to_comm_context_.find(ring_id) != id_to_comm_context_.end(); return id_to_comm_context_.find(unique_comm_key) != id_to_comm_context_.end();
} }
} // namespace distributed } // namespace distributed
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/phi/core/distributed/comm_context.h" #include "paddle/phi/core/distributed/comm_context.h"
...@@ -38,23 +39,25 @@ class CommContextManager { ...@@ -38,23 +39,25 @@ class CommContextManager {
void SetStore(const std::shared_ptr<Store>& store) { store_ = store; } void SetStore(const std::shared_ptr<Store>& store) { store_ = store; }
CommContext* Emplace(int ring_id, std::unique_ptr<CommContext> comm_context); CommContext* Emplace(const std::string& unique_comm_key,
std::unique_ptr<CommContext> comm_context);
CommContext* Get(int ring_id) const; CommContext* Get(const std::string& unique_comm_key) const;
bool Has(int ring_id) const; bool Has(const std::string& unique_comm_key) const;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
static void CreateNCCLCommContext(const std::shared_ptr<Store>& store, static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
int dev_id, const std::string& unique_comm_key,
int ring_id,
int rank, int rank,
int size); int size);
static void SetCUDADeviceId(int dev_id);
#endif #endif
#if defined(PADDLE_WITH_GLOO) #if defined(PADDLE_WITH_GLOO)
static void CreateGlooCommContext(const std::shared_ptr<Store>& store, static void CreateGlooCommContext(const std::shared_ptr<Store>& store,
int ring_id, const std::string& unique_comm_key,
int rank, int rank,
int size); int size);
#endif #endif
...@@ -62,7 +65,8 @@ class CommContextManager { ...@@ -62,7 +65,8 @@ class CommContextManager {
private: private:
DISABLE_COPY_AND_ASSIGN(CommContextManager); DISABLE_COPY_AND_ASSIGN(CommContextManager);
std::unordered_map<int, std::unique_ptr<CommContext>> id_to_comm_context_; std::unordered_map<std::string, std::unique_ptr<CommContext>>
id_to_comm_context_;
std::shared_ptr<Store> store_; std::shared_ptr<Store> store_;
}; };
......
...@@ -151,9 +151,7 @@ def _new_process_group_impl( ...@@ -151,9 +151,7 @@ def _new_process_group_impl(
if backend == "gloo": if backend == "gloo":
pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id) pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
elif backend == "nccl": elif backend == "nccl":
pg = core.ProcessGroupNCCL.create( pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
store, genv.device_id, rank, world_size, group_id
)
elif backend == "xccl": elif backend == "xccl":
pg = core.ProcessGroupCustom.create( pg = core.ProcessGroupCustom.create(
...@@ -344,9 +342,10 @@ def _init_parallel_env(backend): ...@@ -344,9 +342,10 @@ def _init_parallel_env(backend):
) )
if backend == "gloo": if backend == "gloo":
core.CommContextManager.create_gloo_comm_context( core.CommContextManager.create_gloo_comm_context(
store, 0, rank, world_size store, "0", rank, world_size
) )
elif backend == "nccl": elif backend == "nccl":
core.CommContextManager.set_cuda_device_id(dev_id)
core.CommContextManager.create_nccl_comm_context( core.CommContextManager.create_nccl_comm_context(
store, dev_id, 0, rank, world_size store, "0", rank, world_size
) )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册