diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 8a90f231d741b01532fe3c18e11c54648d97f868..91f2db9354c2a00ec7e51ea4595c7cfa00da23ea 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/parallel_executor.h"
+#include <string>
 
 #include "ThreadPool.h"
 
@@ -102,30 +103,43 @@ void ParallelExecutor::BCastParamsToGPUs(
   auto *main_scope = member_->local_scopes_[0];
 
   for (auto *var_desc : startup_program.Block(0).AllVars()) {
+    size_t idx = var_desc->Name().find("@GRAD");
+    if (idx != std::string::npos) continue;
     if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
       auto &main_tensor =
           main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
-      ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
-      auto &dims = main_tensor.dims();
-      size_t numel = main_tensor.numel();
-      platform::NCCLGroupGuard guard;
+      auto &dims = main_tensor.dims();
 
-      for (size_t i = 0; i < member_->places_.size(); ++i) {
-        auto place = member_->places_[i];
-        void *buffer;
-        if (i == 0) {
-          buffer = const_cast<void *>(main_tensor.data<void>());
-        } else {
+      if (paddle::platform::is_gpu_place(main_tensor.place())) {
+        size_t numel = main_tensor.numel();
+        ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
+        platform::NCCLGroupGuard guard;
+        for (size_t i = 0; i < member_->places_.size(); ++i) {
+          auto place = member_->places_[i];
+          void *buffer;
+          if (i == 0) {
+            buffer = const_cast<void *>(main_tensor.data<void>());
+          } else {
+            auto local_scope = member_->local_scopes_[i];
+            auto *t =
+                local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
+            t->Resize(dims);
+            buffer = t->mutable_data(place, main_tensor.type());
+          }
+          auto &nccl_ctx = member_->nccl_ctxs_->at(place);
+          platform::dynload::ncclBcast(buffer, numel, data_type, 0,
+                                       nccl_ctx.comm_, nccl_ctx.stream());
+        }
+      } else {
+        platform::CPUPlace cpu;
+        for (size_t i = 1; i < member_->places_.size(); ++i) {
           auto local_scope = member_->local_scopes_[i];
           auto *t =
               local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
           t->Resize(dims);
-          buffer = t->mutable_data(place, main_tensor.type());
+          t->mutable_data(cpu, main_tensor.type());
+          paddle::framework::TensorCopy(main_tensor, cpu, t);
         }
-
-        auto &nccl_ctx = member_->nccl_ctxs_->at(place);
-        platform::dynload::ncclBcast(buffer, numel, data_type, 0,
-                                     nccl_ctx.comm_, nccl_ctx.stream());
       }
     }
     member_->nccl_ctxs_->WaitAll();
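
Review note: the GPU branch above issues one ncclBcast per device inside a platform::NCCLGroupGuard (which scopes the grouped-call semantics of ncclGroupStart/ncclGroupEnd), using device 0's tensor as the root buffer and freshly resized tensors in the other local scopes as destinations; member_->nccl_ctxs_->WaitAll() then blocks until every stream has drained. The CPU branch falls back to a plain TensorCopy into each local scope, since NCCL cannot broadcast host memory. The standalone sketch below is only meant to illustrate the grouped-broadcast pattern outside of Paddle, using raw NCCL/CUDA; the device count, buffer size, and fill value are illustrative assumptions, and error handling is omitted for brevity.

// Minimal grouped-broadcast sketch (not Paddle code): kNumDevs and kNumel
// are illustrative assumptions.
#include <cuda_runtime.h>
#include <nccl.h>

#include <cstdio>
#include <vector>

int main() {
  const int kNumDevs = 2;         // assumes two visible GPUs
  const size_t kNumel = 1 << 20;  // illustrative parameter size
  int devs[kNumDevs] = {0, 1};

  // One communicator per device, all owned by this single process,
  // analogous to member_->nccl_ctxs_ in the diff.
  ncclComm_t comms[kNumDevs];
  ncclCommInitAll(comms, kNumDevs, devs);

  std::vector<float *> buffers(kNumDevs);
  std::vector<cudaStream_t> streams(kNumDevs);
  for (int i = 0; i < kNumDevs; ++i) {
    cudaSetDevice(devs[i]);
    cudaMalloc(&buffers[i], kNumel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Give the root (device 0) some data to broadcast.
  std::vector<float> host(kNumel, 1.0f);
  cudaSetDevice(devs[0]);
  cudaMemcpy(buffers[0], host.data(), kNumel * sizeof(float),
             cudaMemcpyHostToDevice);

  // Group the per-device calls so NCCL treats them as one collective;
  // this is what platform::NCCLGroupGuard scopes in the diff. Rank 0 is
  // the root: its buffer is the source, every other rank receives.
  ncclGroupStart();
  for (int i = 0; i < kNumDevs; ++i) {
    ncclBcast(buffers[i], kNumel, ncclFloat, /*root=*/0, comms[i],
              streams[i]);
  }
  ncclGroupEnd();

  // Drain all streams, the moral equivalent of nccl_ctxs_->WaitAll().
  for (int i = 0; i < kNumDevs; ++i) {
    cudaSetDevice(devs[i]);
    cudaStreamSynchronize(streams[i]);
  }

  for (int i = 0; i < kNumDevs; ++i) {
    ncclCommDestroy(comms[i]);
    cudaSetDevice(devs[i]);
    cudaFree(buffers[i]);
    cudaStreamDestroy(streams[i]);
  }
  std::printf("broadcast complete\n");
  return 0;
}

The grouping is what makes the per-device loop safe: when a single thread drives multiple devices, issuing the collectives without ncclGroupStart/ncclGroupEnd can deadlock, because each call would wait on peers whose calls have not been launched yet.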