diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index ac4d1f58a5b3b11f034af7618681ebd913d8afb9..9406c6155da860c90739bddac1e81403b094e619 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -145,9 +145,9 @@ void ParallelExecutor::BCastParamsToGPUs(
     auto &dims = main_tensor.dims();
     if (paddle::platform::is_gpu_place(main_tensor.place())) {
 #ifdef PADDLE_WITH_CUDA
+      std::vector<void *> buffers;
       size_t numel = main_tensor.numel();
       ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
-      platform::NCCLGroupGuard guard;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto place = member_->places_[i];
         void *buffer;
@@ -159,11 +159,21 @@ void ParallelExecutor::BCastParamsToGPUs(
           t->Resize(dims);
           buffer = t->mutable_data(place, main_tensor.type());
         }
-        auto &nccl_ctx = member_->nccl_ctxs_->at(place);
-        platform::dynload::ncclBcast(buffer, numel, data_type, 0,
-                                     nccl_ctx.comm_, nccl_ctx.stream());
+        buffers.push_back(buffer);
       }
-      member_->nccl_ctxs_->WaitAll();
+
+      PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(),
+                        "variables' buffer size to bcast NOT equal to places");
+      {
+        platform::NCCLGroupGuard guard;
+        for (size_t i = 0; i < member_->places_.size(); ++i) {
+          auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
+          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
+        }
+        member_->nccl_ctxs_->WaitAll();
+      }
+
 #else
       PADDLE_THROW("Not compiled with CUDA");
 #endif
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index 6f8e3f22db54d166cf97cfdd3d009058207a7ca5..cc46c88fd1f9a5d1bacad26beed6fd0af6405310 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -41,6 +41,11 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
   }
 }
 
+// NOTE(minqiyang): according to the ncclGroupEnd documentation
+// (https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html),
+// ncclGroupEnd will wait for all communicators to be initialized, which
+// causes a blocking problem when a runtime_error is thrown, so only guard
+// the NCCL actions themselves when using it.
 class NCCLGroupGuard {
  public:
  static std::mutex &NCCLMutex() {
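
For context, the pattern the patch moves to is: prepare every destination buffer first, then open the NCCL group only around the ncclBcast calls, so ncclGroupStart()/ncclGroupEnd() never bracket work that can throw. Below is a minimal standalone sketch of that pattern against the raw NCCL/CUDA C API rather than Paddle's platform::dynload wrappers; the GroupGuard class, the two-device setup, and the buffer size are illustrative assumptions, not code from this patch.

// Minimal sketch (not Paddle code): broadcast one buffer from GPU 0 to all
// GPUs, preparing buffers first and grouping only the NCCL calls.
#include <cuda_runtime.h>
#include <nccl.h>

#include <mutex>
#include <vector>

// Simplified stand-in for platform::NCCLGroupGuard: RAII around
// ncclGroupStart()/ncclGroupEnd() plus a process-wide mutex.
class GroupGuard {
 public:
  GroupGuard() : lock_(Mutex()) { ncclGroupStart(); }
  ~GroupGuard() { ncclGroupEnd(); }

 private:
  static std::mutex &Mutex() {
    static std::mutex mtx;
    return mtx;
  }
  std::lock_guard<std::mutex> lock_;
};

int main() {
  const int ndev = 2;            // assumes at least two visible GPUs
  const size_t numel = 1 << 20;  // elements to broadcast

  int devs[ndev] = {0, 1};
  ncclComm_t comms[ndev];
  ncclCommInitAll(comms, ndev, devs);  // one communicator per device

  // Step 1: set up every destination buffer *before* opening the group,
  // mirroring the patch; buffer allocation can fail without leaving a
  // pending ncclGroupEnd() behind.
  std::vector<float *> buffers(ndev);
  std::vector<cudaStream_t> streams(ndev);
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(devs[i]);
    cudaMalloc(reinterpret_cast<void **>(&buffers[i]), numel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Step 2: only the NCCL calls live inside the guard, so ncclGroupStart()
  // and ncclGroupEnd() bracket exactly one batch of ncclBcast() calls.
  {
    GroupGuard guard;
    for (int i = 0; i < ndev; ++i) {
      ncclBcast(buffers[i], numel, ncclFloat, /*root=*/0, comms[i],
                streams[i]);
    }
  }  // ncclGroupEnd() runs here

  // Wait for the broadcasts and clean up.
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(devs[i]);
    cudaStreamSynchronize(streams[i]);
    cudaStreamDestroy(streams[i]);
    cudaFree(buffers[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}

Keeping the guard's scope this narrow is what the NOTE added to nccl_helper.h is about: once ncclGroupStart() has been issued, ncclGroupEnd() waits on all communicators in the group, so an exception thrown between the two calls would leave the process blocked in the guard's destructor.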