Commit c15d2c9e authored by Yu Yang

Update

Parent d470763f
@@ -171,27 +171,28 @@ class ParallelExecutorPrivate {
       return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
     }
 
-    static void InitNCCLContext(std::map<int, NCCLContext> &contexts) {
+    static void InitNCCLContext(std::unordered_map<int, NCCLContext> &contexts,
+                                const std::vector<platform::Place> &places) {
       std::vector<ncclComm_t> comms;
       std::vector<int> devs;
       comms.resize(contexts.size());
       devs.reserve(contexts.size());
-      for (auto &ctx : contexts) {
-        devs.push_back(ctx.first);
+      for (auto &p : places) {
+        devs.push_back(boost::get<platform::CUDAPlace>(p).device);
       }
 
       NCCL_INVOKE(platform::dynload::ncclCommInitAll(
           &comms[0], static_cast<int>(contexts.size()), &devs[0]));
 
       int i = 0;
-      for (auto &ctx : contexts) {
-        ctx.second.comm = comms[i++];
+      for (auto &dev_id : devs) {
+        contexts.at(dev_id).comm = comms[i++];
       }
     }
   };
 
-  std::map<int, NCCLContext> communication_streams_;
+  std::unordered_map<int, NCCLContext> communication_streams_;
 
   NCCLContext &GetNCCLCtx(platform::Place p) {
     int dev_id = boost::get<platform::CUDAPlace>(p).device;
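The two loop changes in this hunk follow directly from the container change: `std::map` iterates in sorted key order, so assigning `comms[i]` to the i-th map entry happened to line up with the sorted `devs` list, but `std::unordered_map` guarantees no iteration order at all. Deriving `devs` from `places` and writing back through `contexts.at(dev_id)` keeps `comms[i]` paired with the device it was created for. A minimal standalone sketch of the pattern (the `DeviceCtx` struct and `InitComms` name are illustrative stand-ins, not Paddle identifiers):

```cpp
// A minimal sketch, assuming NCCL 2.x. DeviceCtx/InitComms are
// hypothetical stand-ins for Paddle's NCCLContext/InitNCCLContext.
#include <unordered_map>
#include <vector>

#include <nccl.h>

struct DeviceCtx {
  ncclComm_t comm;  // NCCL communicator owned by one GPU
};

void InitComms(std::unordered_map<int, DeviceCtx> &contexts,
               const std::vector<int> &dev_ids) {
  std::vector<ncclComm_t> comms(dev_ids.size());
  // One call builds a clique of communicators, one per listed device;
  // comms[i] belongs to dev_ids[i].
  ncclCommInitAll(comms.data(), static_cast<int>(dev_ids.size()),
                  dev_ids.data());
  // Store each communicator under its device id. Indexing by dev_ids
  // (not by map iteration) keeps the pairing correct even though the
  // unordered_map has no stable iteration order.
  for (size_t i = 0; i < dev_ids.size(); ++i) {
    contexts.at(dev_ids[i]).comm = comms[i];
  }
}
```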
@@ -493,13 +494,20 @@ void ParallelExecutor::BCastParamsToGPUs(
       platform::dynload::ncclGroupStart();
 
-      for (auto &place : member_->places_) {
-        auto local_scope = member_->local_scopes_[place];
-        auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
-        t->Resize(dims);
+      for (size_t i = 0; i < member_->places_.size(); ++i) {
+        auto place = member_->places_[i];
+        void *buffer;
+        if (i == 0) {
+          buffer = const_cast<void *>(main_tensor.data<void>());
+        } else {
+          auto local_scope = member_->local_scopes_[place];
+          auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
+          t->Resize(dims);
+          buffer = t->mutable_data(place, main_tensor.type());
+        }
 
         auto &nccl_ctx = member_->GetNCCLCtx(place);
-        platform::dynload::ncclBcast(t->mutable_data(place, main_tensor.type()),
-                                     numel, data_type, 0, nccl_ctx.comm,
+        platform::dynload::ncclBcast(buffer, numel, data_type, 0, nccl_ctx.comm,
                                      nccl_ctx.stream());
       }
       platform::dynload::ncclGroupEnd();
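The rewritten loop removes a redundant allocation on the root device: `ncclBcast` is an in-place collective, so rank 0 can hand NCCL the buffer of `main_tensor` directly as the source, while every other rank allocates a destination tensor and receives into it. A minimal sketch of the same pattern against the raw NCCL API (the `BcastParam` name and buffer plumbing are assumptions for illustration, assuming one communicator and one stream per GPU):

```cpp
// A minimal sketch of a single-process, multi-GPU broadcast with NCCL.
#include <cstddef>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

void BcastParam(void *root_buffer,                        // parameter on GPU 0
                const std::vector<void *> &recv_buffers,  // per-GPU destinations
                size_t numel, ncclDataType_t dtype,
                const std::vector<ncclComm_t> &comms,
                const std::vector<cudaStream_t> &streams) {
  // Group the per-device calls so NCCL launches them as one collective;
  // issuing them sequentially from a single thread would deadlock.
  ncclGroupStart();
  for (size_t i = 0; i < comms.size(); ++i) {
    // ncclBcast is in-place: on the root rank the buffer is the source,
    // on every other rank it is the destination.
    void *buf = (i == 0) ? root_buffer : recv_buffers[i];
    ncclBcast(buf, numel, dtype, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();
}
```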
@@ -533,7 +541,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
   }
 
   ParallelExecutorPrivate::NCCLContext::InitNCCLContext(
-      member_->communication_streams_);
+      member_->communication_streams_, member_->places_);
 #endif
 }
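Threading `member_->places_` into `InitNCCLContext` keeps a single source of truth for the participating devices: the device ids handed to `ncclCommInitAll` come from the same place list that populated `communication_streams_`, so the `contexts.at(dev_id)` lookup during initialization cannot miss.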