From 76b321872a0c449d4f082cf437672e1d9e7510cb Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 14 Oct 2019 09:40:04 +0800 Subject: [PATCH] fix cuda dev_ctx by event, test=develop (#20553) --- .../details/sparse_all_reduce_op_handle.cc | 31 ++++++++++++++++--- .../details/sparse_all_reduce_op_handle.h | 2 ++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc index 070a17a9de5..e69bda6fcf8 100644 --- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h" #include +#include #include "dgc/dgc.h" #include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/reduce_and_gather.h" @@ -41,11 +42,30 @@ SparseAllReduceOpHandle::SparseAllReduceOpHandle( } } +void SparseAllReduceOpHandle::WaitInputVarGenerated() { +#ifdef PADDLE_WITH_CUDA + for (auto &p : dev_ctxes_) { + if (platform::is_gpu_place(p.first)) { + int dev_id = boost::get(p.first).device; + auto *compute_dev_ctx = + platform::DeviceContextPool::Instance().GetByPlace( + platform::CUDAPlace(dev_id)); + auto *dev_ctx = static_cast(p.second); + if (compute_dev_ctx->stream() != dev_ctx->stream()) { + auto &event = events_.at(dev_id); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventRecord(event, compute_dev_ctx->stream())); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamWaitEvent(dev_ctx->stream(), event, 0)); + } + } + } +#endif +} + void SparseAllReduceOpHandle::RunImplEncoded() { platform::RecordEvent record_event(Name()); - WaitInputVarGenerated(); - auto in_var_handles = DynamicCast(this->Inputs()); auto out_var_handles = DynamicCast(this->Outputs()); PADDLE_ENFORCE_EQ( @@ -87,6 +107,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() { PADDLE_ENFORCE(nranks_ > 1); std::vector> all_reduce_calls; + std::vector allocations; + for (size_t i = 0; i < local_scopes_.size(); ++i) { auto &place = places_[i]; auto &in = *ins[i]; @@ -104,7 +126,6 @@ void SparseAllReduceOpHandle::RunImplEncoded() { int dev_id = boost::get(place).device; auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false); auto &nccl_ctx = nccl_ctxs->at(dev_id); - auto *dev_ctx = nccl_ctxs->DevCtx(dev_id); auto stream = nccl_ctx.stream(); auto comm = nccl_ctx.comm_; @@ -112,8 +133,9 @@ void SparseAllReduceOpHandle::RunImplEncoded() { // dgc use ncclAllGather to get all the encoded data // so the buffer need nranks. int buf_size = nranks_ * encode_size; - auto tmp_ious_data = memory::Alloc(*dev_ctx, buf_size); + auto tmp_ious_data = memory::Alloc(place, buf_size); void *gather_buff = reinterpret_cast(tmp_ious_data->ptr()); + allocations.emplace_back(std::move(tmp_ious_data)); VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size @@ -126,6 +148,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() { }); } + WaitInputVarGenerated(); NCCLAllReduceFunc(all_reduce_calls); } diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h index 9802f8dba7e..d15814a2197 100644 --- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h @@ -36,6 +36,8 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle { bool is_encoded = false, int nranks = -1); std::string Name() const override; + void WaitInputVarGenerated() override; + protected: void RunImpl() override; int GetKValue(const std::string &grad_name); -- GitLab