未验证 提交 121b2aed 编写于 作者: Y Yi Liu 提交者: GitHub

initialize global nccl context in dygraph (#23037)

initialize global nccl context in dygraph
test=develop
上级 5a202af9
......@@ -10,7 +10,7 @@ cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc)
if(NOT WIN32)
if(WITH_NCCL)
cc_library(nccl_context SRCS nccl_context.cc DEPS device_context)
cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context)
endif()
cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif(NOT WIN32)
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/platform/collective_helper.h"
namespace paddle {
namespace imperative {
......@@ -115,7 +116,6 @@ void NCCLParallelContext::BcastNCCLId(ncclUniqueId *nccl_id, int root) {
void NCCLParallelContext::Init() {
ncclUniqueId nccl_id;
ncclComm_t comm;
if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
platform::dynload::ncclGetUniqueId(&nccl_id);
......@@ -128,12 +128,13 @@ void NCCLParallelContext::Init() {
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id;
PADDLE_ENFORCE(cudaSetDevice(gpu_id));
PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
&comm, strategy_.nranks_, nccl_id, strategy_.local_rank_));
platform::NCCLComm *nccl_comm =
platform::NCCLCommContext::Instance().CreateNCCLComm(
&nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(pool.Get(place_));
dev_ctx->set_nccl_comm(comm);
dev_ctx->set_nccl_comm(nccl_comm->comm());
}
#endif
......
......@@ -78,9 +78,7 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps} dlpack cudnn_workspace_helper)
if (WITH_DISTRIBUTE)
cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
endif()
cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
if(WIN32)
if(WITH_GPU AND NOT WITH_DSO)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册