From 79be06045c2cfd97b14991dac5bdbe2a2fa765db Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Fri, 13 Apr 2018 16:43:44 +0800
Subject: [PATCH] Support CPU/GPU mixture for ParallelExecutor

---
 .../details/nccl_all_reduce_op_handle.cc         | 13 +++++++++++++
 paddle/fluid/framework/details/op_handle_base.cc | 16 ++++++++++++++++
 paddle/fluid/framework/details/op_handle_base.h  |  3 +++
 3 files changed, 32 insertions(+)

diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
index 3547a6e21c7..1e48f75958a 100644
--- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
@@ -116,6 +116,19 @@ void NCCLAllReduceOpHandle::RunImpl() {
       // Reduce All Tensor to trg in CPU
       ReduceLoDTensor func(lod_tensors, &trg);
       VisitDataType(ToDataType(lod_tensors[0].type()), func);
+
+      for (size_t i = 0; i < local_scopes_.size(); ++i) {
+        auto &scope = local_scopes_[i];
+        auto &p = places_[i];
+        auto *var = scope->FindVar(var_name);
+        auto *dev_ctx = dev_ctxes_[p];
+
+        RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
+          auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
+          auto &tensor_cpu = trg;
+          TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
+        });
+      }
     }
   }
 }
diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc
index 846bc21be27..28f1e7b5088 100644
--- a/paddle/fluid/framework/details/op_handle_base.cc
+++ b/paddle/fluid/framework/details/op_handle_base.cc
@@ -107,6 +107,22 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 #endif
 }
 
+void OpHandleBase::RunAndRecordEvent(platform::Place p,
+                                     const std::function<void()> &callback) {
+  if (platform::is_cpu_place(p) || events_.empty()) {
+    callback();
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    auto *ctx = dev_ctxes_.at(p);
+    auto *cuda_ctx = static_cast<platform::CUDADeviceContext *>(ctx);
+    cuda_ctx->RecordEvent(events_.at(boost::get<platform::CUDAPlace>(p).device),
+                          callback);
+#else
+    PADDLE_THROW("Not implemented");
+#endif
+  }
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h
index 1aacba5a4c3..a9a6c8d39cf 100644
--- a/paddle/fluid/framework/details/op_handle_base.h
+++ b/paddle/fluid/framework/details/op_handle_base.h
@@ -64,6 +64,9 @@ class OpHandleBase {
  protected:
   void RunAndRecordEvent(const std::function<void()> &callback);
 
+  void RunAndRecordEvent(platform::Place p,
+                         const std::function<void()> &callback);
+
   virtual void RunImpl() = 0;
 };
 
-- 
GitLab