diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index 3547a6e21c7770ef800a30a0971e93d0120fe245..1e48f75958a3ada4d1cd5c8d0f920da4fed2157e 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -116,6 +116,19 @@ void NCCLAllReduceOpHandle::RunImpl() { // Reduce All Tensor to trg in CPU ReduceLoDTensor func(lod_tensors, &trg); VisitDataType(ToDataType(lod_tensors[0].type()), func); + + for (size_t i = 0; i < local_scopes_.size(); ++i) { + auto &scope = local_scopes_[i]; + auto &p = places_[i]; + auto *var = scope->FindVar(var_name); + auto *dev_ctx = dev_ctxes_[p]; + + RunAndRecordEvent(p, [&trg, var, dev_ctx, p] { + auto &tensor_gpu = *var->GetMutable(); + auto &tensor_cpu = trg; + TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu); + }); + } } } } diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 846bc21be27cf6c889ae34b967b8bff3c60ab743..28f1e7b5088d45cefa94c6ed7ca36ef8e3ce65d4 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -107,6 +107,22 @@ void OpHandleBase::RunAndRecordEvent(const std::function &callback) { #endif } +void OpHandleBase::RunAndRecordEvent(platform::Place p, + const std::function &callback) { + if (platform::is_cpu_place(p) || events_.empty()) { + callback(); + } else { +#ifdef PADDLE_WITH_CUDA + auto *ctx = dev_ctxes_.at(p); + auto *cuda_ctx = static_cast(ctx); + cuda_ctx->RecordEvent(events_.at(boost::get(p).device), + callback); +#else + PADDLE_THROW("Not implemented"); +#endif + } +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 1aacba5a4c3c959b6f584aa5f4dcdc5c0dc43e76..a9a6c8d39cf8741f7d9c91579a650ad742cec381 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -64,6 +64,9 @@ class OpHandleBase { protected: void RunAndRecordEvent(const std::function &callback); + void RunAndRecordEvent(platform::Place p, + const std::function &callback); + virtual void RunImpl() = 0; };