From 60d0a0594e4cf0152459646f36fa71d3f454856f Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Wed, 28 Mar 2018 17:13:25 +0800
Subject: [PATCH] refine parallel

---
 paddle/fluid/framework/parallel_executor.cc | 44 ++++++++++++++-------
 1 file changed, 29 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 8a90f231d7..91f2db9354 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/parallel_executor.h"
+#include <string>
 
 #include "ThreadPool.h"
 
@@ -102,30 +103,43 @@ void ParallelExecutor::BCastParamsToGPUs(
   auto *main_scope = member_->local_scopes_[0];
 
   for (auto *var_desc : startup_program.Block(0).AllVars()) {
+    size_t idx = var_desc->Name().find("@GRAD");
+    if (idx != std::string::npos) continue;
     if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
       auto &main_tensor =
           main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
-      ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
-      auto &dims = main_tensor.dims();
-      size_t numel = main_tensor.numel();
-      platform::NCCLGroupGuard guard;
+      auto &dims = main_tensor.dims();
 
-      for (size_t i = 0; i < member_->places_.size(); ++i) {
-        auto place = member_->places_[i];
-        void *buffer;
-        if (i == 0) {
-          buffer = const_cast<void *>(main_tensor.data<void>());
-        } else {
+      if (paddle::platform::is_gpu_place(main_tensor.place())) {
+        size_t numel = main_tensor.numel();
+        ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
+        platform::NCCLGroupGuard guard;
+        for (size_t i = 0; i < member_->places_.size(); ++i) {
+          auto place = member_->places_[i];
+          void *buffer;
+          if (i == 0) {
+            buffer = const_cast<void *>(main_tensor.data<void>());
+          } else {
+            auto local_scope = member_->local_scopes_[i];
+            auto *t =
+                local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
+            t->Resize(dims);
+            buffer = t->mutable_data(place, main_tensor.type());
+          }
+          auto &nccl_ctx = member_->nccl_ctxs_->at(place);
+          platform::dynload::ncclBcast(buffer, numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
+        }
+      } else {
+        platform::CPUPlace cpu;
+        for (size_t i = 1; i < member_->places_.size(); ++i) {
           auto local_scope = member_->local_scopes_[i];
           auto *t =
               local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
           t->Resize(dims);
-          buffer = t->mutable_data(place, main_tensor.type());
+          t->mutable_data(cpu, main_tensor.type());
+          paddle::framework::TensorCopy(main_tensor, cpu, t);
         }
-
-        auto &nccl_ctx = member_->nccl_ctxs_->at(place);
-        platform::dynload::ncclBcast(buffer, numel, data_type, 0,
-                                     nccl_ctx.comm_, nccl_ctx.stream());
       }
     }
     member_->nccl_ctxs_->WaitAll();
-- 
GitLab
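
Note (not part of the patch): after this change, BCastParamsToGPUs skips gradient variables (any name containing "@GRAD") and splits the broadcast into two paths: GPU-resident parameters are broadcast from device 0 with ncclBcast under a NCCLGroupGuard, while CPU-resident parameters are replicated into each local scope with TensorCopy. The standalone C++ sketch below models only that control flow; Scope, Tensor, and BroadcastParams are hypothetical stand-ins, not Paddle APIs.

// Standalone sketch (hypothetical types, not Paddle APIs): models the
// parameter-broadcast flow the patch gives BCastParamsToGPUs.
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Tensor = std::vector<float>;            // stand-in for LoDTensor
using Scope = std::map<std::string, Tensor>;  // stand-in for framework::Scope

// Replicate every parameter from the main scope (index 0) into the other
// local scopes, skipping gradient variables whose names contain "@GRAD".
void BroadcastParams(std::vector<Scope> &local_scopes) {
  const Scope &main_scope = local_scopes[0];
  for (const auto &var : main_scope) {
    // The patch adds exactly this filter before broadcasting.
    if (var.first.find("@GRAD") != std::string::npos) continue;
    for (size_t i = 1; i < local_scopes.size(); ++i) {
      // In the real code this assignment is ncclBcast from device 0 for GPU
      // tensors, or mutable_data followed by TensorCopy for CPU tensors.
      local_scopes[i][var.first] = var.second;
    }
  }
}

int main() {
  std::vector<Scope> scopes(2);
  scopes[0]["fc.w"] = {1.f, 2.f, 3.f};       // a parameter
  scopes[0]["fc.w@GRAD"] = {0.f, 0.f, 0.f};  // its gradient
  BroadcastParams(scopes);
  std::cout << scopes[1].count("fc.w")       // 1: parameter replicated
            << scopes[1].count("fc.w@GRAD")  // 0: gradient skipped
            << "\n";
}

Compiled and run, the sketch prints "10": the parameter reached the replica scope while the gradient variable was skipped, mirroring the @GRAD check this commit introduces.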