diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 3b09c8402b885c8790ef410fd424b3cc832e377d..7db8d003b408b59c26d6af81833d491ad8623364 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -10,7 +10,7 @@ cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator) cc_library(imperative_profiler SRCS profiler.cc) if(NOT WIN32) if(WITH_NCCL) - cc_library(nccl_context SRCS nccl_context.cc DEPS device_context) + cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context) endif() cc_library(data_loader SRCS data_loader.cc DEPS enforce) endif(NOT WIN32) diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index 3c3634b0bd804a7da0cefd3ffb72533d66546492..bc71140058934036c76dc50bd498a9fb99258952 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "paddle/fluid/imperative/nccl_context.h" +#include "paddle/fluid/platform/collective_helper.h" namespace paddle { namespace imperative { @@ -115,7 +116,6 @@ void NCCLParallelContext::BcastNCCLId(ncclUniqueId *nccl_id, int root) { void NCCLParallelContext::Init() { ncclUniqueId nccl_id; - ncclComm_t comm; if (strategy_.local_rank_ == 0) { // generate the unique ncclid on the root worker platform::dynload::ncclGetUniqueId(&nccl_id); @@ -128,12 +128,13 @@ void NCCLParallelContext::Init() { << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id; PADDLE_ENFORCE(cudaSetDevice(gpu_id)); - PADDLE_ENFORCE(platform::dynload::ncclCommInitRank( - &comm, strategy_.nranks_, nccl_id, strategy_.local_rank_)); + platform::NCCLComm *nccl_comm = + platform::NCCLCommContext::Instance().CreateNCCLComm( + &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(pool.Get(place_)); - dev_ctx->set_nccl_comm(comm); + dev_ctx->set_nccl_comm(nccl_comm->comm()); } #endif diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 08ed66542a72fe1e853dd24c7b9bf4e16423253e..810b9e86b0c44b0fb4877e53e3dc88af1d570a22 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -78,9 +78,7 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} ${dgc_deps} dlpack cudnn_workspace_helper) -if (WITH_DISTRIBUTE) - cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce) -endif() +cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce) if(WIN32) if(WITH_GPU AND NOT WITH_DSO)