Commit d470763f authored by Yu Yang

Stash

Parent 9fc0b596
@@ -154,6 +154,8 @@ class ParallelExecutorPrivate {
   std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
       local_scopes_;
+  std::vector<platform::Place> places_;
+
 #ifdef PADDLE_WITH_CUDA
   struct NCCLContext {
     std::unique_ptr<platform::CUDADeviceContext> ctx_;
@@ -246,6 +248,8 @@ ParallelExecutor::ParallelExecutor(
     const ProgramDesc &startup_program, const ProgramDesc &main_program,
     const std::string &loss_var_name, Scope *scope)
     : member_(new ParallelExecutorPrivate()) {
+  member_->places_ = places;
+
   // Step 1. RunStartupProgram and Bcast the params to devs.
   Executor exe(places[0]);
   exe.Run(startup_program, scope, 0);
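
Note on the first two hunks: they cache the constructor's device list in ParallelExecutorPrivate::places_. The commit message ("Stash") gives no rationale, but a likely one is iteration order: local_scopes_ is a std::unordered_map, whose traversal order is unspecified, while a std::vector preserves the order the caller passed in. A minimal standalone sketch of the distinction (not Paddle code):

```cpp
// Standalone sketch: unordered_map iteration order is unspecified,
// while a vector keeps insertion order, so looping over a vector of
// places gives every run (and every participant) the same device order.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<std::string> places = {"gpu:0", "gpu:1", "gpu:2"};
  std::unordered_map<std::string, int> scopes = {
      {"gpu:0", 0}, {"gpu:1", 1}, {"gpu:2", 2}};

  // Deterministic: always gpu:0, gpu:1, gpu:2.
  for (const auto &p : places) std::cout << p << ' ';
  std::cout << '\n';

  // Unspecified order: may differ between runs or standard library builds.
  for (const auto &kv : scopes) std::cout << kv.first << ' ';
  std::cout << '\n';
}
```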
@@ -489,14 +493,14 @@ void ParallelExecutor::BCastParamsToGPUs(
       platform::dynload::ncclGroupStart();

-      for (auto &pair : member_->local_scopes_) {
-        auto local_scope = pair.second;
+      for (auto &place : member_->places_) {
+        auto local_scope = member_->local_scopes_[place];
         auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
         t->Resize(dims);
-        auto &nccl_ctx = member_->GetNCCLCtx(pair.first);
-        platform::dynload::ncclBcast(
-            t->mutable_data(pair.first, main_tensor.type()), numel, data_type,
-            0, nccl_ctx.comm, nccl_ctx.stream());
+        auto &nccl_ctx = member_->GetNCCLCtx(place);
+        platform::dynload::ncclBcast(t->mutable_data(place, main_tensor.type()),
+                                     numel, data_type, 0, nccl_ctx.comm,
+                                     nccl_ctx.stream());
       }
       platform::dynload::ncclGroupEnd();
     }
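
The broadcast loop above follows NCCL's single-process, multi-GPU pattern: when one host thread drives several GPUs, every per-device collective must be issued between one ncclGroupStart()/ncclGroupEnd() pair, or the calls can deadlock. A minimal sketch against the raw NCCL API (Paddle reaches the same functions through platform::dynload; the helper name and setup here are illustrative):

```cpp
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// Broadcast each device's buffer from root rank 0. One host thread issues
// all per-device calls, so they are fenced by a single NCCL group.
void BcastToAll(const std::vector<int> &devices,
                const std::vector<float *> &bufs, size_t numel,
                const std::vector<ncclComm_t> &comms,
                const std::vector<cudaStream_t> &streams) {
  ncclGroupStart();
  for (size_t i = 0; i < devices.size(); ++i) {
    cudaSetDevice(devices[i]);
    // Rank 0 sends its copy; every other rank overwrites its buffer.
    ncclBcast(bufs[i], numel, ncclFloat, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();
}
```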
@@ -506,7 +510,7 @@ void ParallelExecutor::BCastParamsToGPUs(
   for (auto &pair : member_->local_scopes_) {
     member_->GetNCCLCtx(pair.first).ctx_->Wait();

-    auto &b = pair.second->FindVar("fc_1.b_0")->Get<framework::LoDTensor>();
+    auto &b = pair.second->FindVar("fc_0.b_0")->Get<framework::LoDTensor>();
     framework::LoDTensor cpu;
     framework::TensorCopy(b, platform::CPUPlace(), &cpu);
     platform::DeviceContextPool::Instance().Get(b.place())->Wait();
......
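
The final hunk reads like a temporary sanity check, consistent with the "Stash" message: it fetches one bias parameter (fc_0.b_0) in every local scope, copies it to the CPU, and waits, so the broadcast result can be inspected. A sketch of that verification pattern in plain CUDA, with a hypothetical buffer layout (replica i lives on device i; not Paddle code):

```cpp
#include <cuda_runtime.h>
#include <cmath>
#include <vector>

// Hypothetical helper: copy each device's replica of a parameter to the
// host and compare it against device 0's copy after a broadcast.
bool ReplicasMatch(const std::vector<float *> &dev_bufs, size_t numel) {
  std::vector<std::vector<float>> host(dev_bufs.size(),
                                       std::vector<float>(numel));
  for (size_t i = 0; i < dev_bufs.size(); ++i) {
    cudaSetDevice(static_cast<int>(i));
    cudaDeviceSynchronize();  // ensure the broadcast stream has finished,
                              // as the Paddle snippet does via ctx_->Wait()
    cudaMemcpy(host[i].data(), dev_bufs[i], numel * sizeof(float),
               cudaMemcpyDeviceToHost);
  }
  for (size_t i = 1; i < host.size(); ++i) {
    for (size_t j = 0; j < numel; ++j) {
      if (std::fabs(host[i][j] - host[0][j]) > 1e-6f) return false;
    }
  }
  return true;
}
```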