From 121b2aed4de6c36d7d68dfd126c58081cfe6a114 Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Wed, 18 Mar 2020 20:26:52 +0800 Subject: [PATCH] initialize global nccl context in dygraph (#23037) initialize global nccl context in dygraph test=develop --- paddle/fluid/imperative/CMakeLists.txt | 2 +- paddle/fluid/imperative/nccl_context.cc | 9 +++++---- paddle/fluid/platform/CMakeLists.txt | 4 +--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 3b09c8402b..7db8d003b4 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -10,7 +10,7 @@ cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator) cc_library(imperative_profiler SRCS profiler.cc) if(NOT WIN32) if(WITH_NCCL) - cc_library(nccl_context SRCS nccl_context.cc DEPS device_context) + cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context) endif() cc_library(data_loader SRCS data_loader.cc DEPS enforce) endif(NOT WIN32) diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index 3c3634b0bd..bc71140058 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/imperative/nccl_context.h" +#include "paddle/fluid/platform/collective_helper.h" namespace paddle { namespace imperative { @@ -115,7 +116,6 @@ void NCCLParallelContext::BcastNCCLId(ncclUniqueId *nccl_id, int root) { void NCCLParallelContext::Init() { ncclUniqueId nccl_id; - ncclComm_t comm; if (strategy_.local_rank_ == 0) { // generate the unique ncclid on the root worker platform::dynload::ncclGetUniqueId(&nccl_id); @@ -128,12 +128,13 @@ void NCCLParallelContext::Init() { << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id; PADDLE_ENFORCE(cudaSetDevice(gpu_id)); - PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( - &comm, strategy_.nranks_, nccl_id, strategy_.local_rank_)); + platform::NCCLComm *nccl_comm = + platform::NCCLCommContext::Instance().CreateNCCLComm( + &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto *dev_ctx = static_cast(pool.Get(place_)); - dev_ctx->set_nccl_comm(comm); + dev_ctx->set_nccl_comm(nccl_comm->comm()); } #endif diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 08ed66542a..810b9e86b0 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -78,9 +78,7 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} ${dgc_deps} dlpack cudnn_workspace_helper) -if (WITH_DISTRIBUTE) - cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce) -endif() +cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce) if(WIN32) if(WITH_GPU AND NOT WITH_DSO) -- GitLab