Unverified commit 121b2aed authored by Yi Liu, committed by GitHub

initialize global nccl context in dygraph (#23037)

initialize global nccl context in dygraph
test=develop
Parent: 5a202af9
@@ -10,7 +10,7 @@ cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator)
 cc_library(imperative_profiler SRCS profiler.cc)
 if(NOT WIN32)
   if(WITH_NCCL)
-    cc_library(nccl_context SRCS nccl_context.cc DEPS device_context)
+    cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context)
   endif()
   cc_library(data_loader SRCS data_loader.cc DEPS enforce)
 endif(NOT WIN32)
......
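The new collective_helper dependency matches the source change in the next file, where nccl_context.cc now includes collective_helper.h and creates its communicator through the global platform::NCCLCommContext.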
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/imperative/nccl_context.h"
+#include "paddle/fluid/platform/collective_helper.h"

 namespace paddle {
 namespace imperative {
@@ -115,7 +116,6 @@ void NCCLParallelContext::BcastNCCLId(ncclUniqueId *nccl_id, int root) {

 void NCCLParallelContext::Init() {
   ncclUniqueId nccl_id;
-  ncclComm_t comm;
   if (strategy_.local_rank_ == 0) {
     // generate the unique ncclid on the root worker
     platform::dynload::ncclGetUniqueId(&nccl_id);
@@ -128,12 +128,13 @@ void NCCLParallelContext::Init() {
           << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id;

   PADDLE_ENFORCE(cudaSetDevice(gpu_id));
-  PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
-      &comm, strategy_.nranks_, nccl_id, strategy_.local_rank_));
+  platform::NCCLComm *nccl_comm =
+      platform::NCCLCommContext::Instance().CreateNCCLComm(
+          &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0);

   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(pool.Get(place_));
-  dev_ctx->set_nccl_comm(comm);
+  dev_ctx->set_nccl_comm(nccl_comm->comm());
 }
 #endif
......
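For readability, here is a minimal sketch of NCCLParallelContext::Init() as it looks after this patch, assembled from the hunks above; the nccl_id broadcast between trainers and the derivation of gpu_id from place_ sit in elided context lines and are only indicated by comments.

// Sketch only: assembled from the visible hunks, not the full source file.
void NCCLParallelContext::Init() {
  ncclUniqueId nccl_id;
  if (strategy_.local_rank_ == 0) {
    // generate the unique ncclid on the root worker
    platform::dynload::ncclGetUniqueId(&nccl_id);
  }
  // ... elided: nccl_id is broadcast to the other trainers and gpu_id is
  // obtained from place_ (unchanged context lines, not shown in the diff) ...

  PADDLE_ENFORCE(cudaSetDevice(gpu_id));

  // Instead of initializing a function-local ncclComm_t via ncclCommInitRank,
  // the communicator is created and owned by the global NCCLCommContext
  // singleton (ring id 0), making it reachable outside this method.
  platform::NCCLComm *nccl_comm =
      platform::NCCLCommContext::Instance().CreateNCCLComm(
          &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0);

  // The per-device CUDADeviceContext still receives the raw ncclComm_t handle.
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto *dev_ctx =
      static_cast<platform::CUDADeviceContext *>(pool.Get(place_));
  dev_ctx->set_nccl_comm(nccl_comm->comm());
}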
@@ -78,9 +78,7 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool
     place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
     ${dgc_deps} dlpack cudnn_workspace_helper)

-if (WITH_DISTRIBUTE)
-  cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
-endif()
+cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)

 if(WIN32)
   if(WITH_GPU AND NOT WITH_DSO)
......
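Note that collective_helper is now built unconditionally instead of only under WITH_DISTRIBUTE; presumably this is so the dygraph nccl_context target above can link against it in any WITH_NCCL build, even when distributed training support is not compiled in.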