From 6b20b35589c3443bbd49fde2b71b5c4e0e5b8cc0 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 13 Apr 2018 15:22:04 +0800 Subject: [PATCH] Fix Transformer Hang Problem --- .../details/computation_op_handle.cc | 4 ++- .../details/nccl_all_reduce_op_handle.cc | 10 +++--- .../fluid/framework/details/op_handle_base.cc | 32 ++++++++++++------- .../fluid/framework/details/op_handle_base.h | 2 ++ .../details/scale_loss_grad_op_handle.cc | 14 +++++--- .../fluid/framework/details/send_op_handle.cc | 2 +- .../details/threaded_ssa_graph_executor.cc | 4 ++- paddle/fluid/platform/device_context.cc | 2 +- paddle/fluid/platform/device_context.h | 9 +++++- 9 files changed, 54 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index e3f8bbb72f2..ff6d91c1daf 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -35,7 +35,9 @@ void ComputationOpHandle::RunImpl() { } } - op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); + this->RunAndRecordEvent([this] { + op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); + }); } std::string ComputationOpHandle::Name() const { return op_->Type(); } diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index 55b5f113589..0611ec6376d 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -69,10 +69,12 @@ void NCCLAllReduceOpHandle::RunImpl() { }); } - platform::NCCLGroupGuard guard; - for (auto &call : all_reduce_calls) { - call(); - } + this->RunAndRecordEvent([&] { + platform::NCCLGroupGuard guard; + for (auto &call : all_reduce_calls) { + call(); + } + }); } } diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index e4194a7442f..846bc21be27 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -54,17 +54,6 @@ void OpHandleBase::Run(bool use_event) { #endif RunImpl(); - -#ifdef PADDLE_WITH_CUDA - if (use_event) { - for (auto &p : dev_ctxes_) { - int dev_id = boost::get(p.first).device; - auto stream = - static_cast(p.second)->stream(); - PADDLE_ENFORCE(cudaEventRecord(events_.at(dev_id), stream)); - } - } -#endif } void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { @@ -97,6 +86,27 @@ void OpHandleBase::AddOutput(VarHandleBase *out) { out->generated_op_ = this; } +void OpHandleBase::RunAndRecordEvent(const std::function &callback) { +#ifdef PADDLE_WITH_CUDA + if (!events_.empty()) { // Use event + std::function method = callback; + + for (auto &p : dev_ctxes_) { + method = [method, p, this]() { + static_cast(p.second)->RecordEvent( + events_.at(boost::get(p.first).device), + method); + }; + } + method(); + } else { +#endif + callback(); +#ifdef PADDLE_WITH_CUDA + } +#endif +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index fbdb54ba8d9..1aacba5a4c3 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -62,6 +62,8 @@ class OpHandleBase { virtual bool IsMultiDeviceTransfer() { return false; } protected: + void RunAndRecordEvent(const std::function &callback); + virtual void RunImpl() = 0; }; diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index 0a6f6129b81..7fb9f99a8a1 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -14,6 +14,8 @@ #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" +#include + namespace paddle { namespace framework { namespace details { @@ -37,11 +39,13 @@ void ScaleLossGradOpHandle::RunImpl() { *tmp = coeff_; } else { #ifdef PADDLE_WITH_CUDA - auto stream = - static_cast(this->dev_ctxes_[place_]) - ->stream(); - memory::Copy(boost::get(place_), tmp, - platform::CPUPlace(), &coeff_, sizeof(float), stream); + this->RunAndRecordEvent([&] { + auto stream = + static_cast(this->dev_ctxes_[place_]) + ->stream(); + memory::Copy(boost::get(place_), tmp, + platform::CPUPlace(), &coeff_, sizeof(float), stream); + }); #endif } } diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc index d181607e863..549b9d9abbe 100644 --- a/paddle/fluid/framework/details/send_op_handle.cc +++ b/paddle/fluid/framework/details/send_op_handle.cc @@ -34,7 +34,7 @@ void SendOpHandle::RunImpl() { } in->generated_op_->Wait(dev_ctxes_[p]); } - op_->Run(*local_scope_, place_); + this->RunAndRecordEvent([&] { op_->Run(*local_scope_, place_); }); } std::string SendOpHandle::Name() const { return "send"; } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 1ce69ab02b0..a371ee10fe0 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -196,10 +196,12 @@ void ThreadedSSAGraphExecutor::RunOp( BlockingQueue *ready_var_q, details::OpHandleBase *op) { auto op_run = [ready_var_q, op, this] { try { - VLOG(10) << op->Name() << " : " << op->DebugString(); + VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); op->Run(use_event_); + VLOG(10) << op << " " << op->Name() << " Done "; running_ops_--; ready_var_q->Extend(op->outputs_); + VLOG(10) << op << " " << op->Name() << "Signal posted"; } catch (platform::EnforceNotMet ex) { exception_.reset(new platform::EnforceNotMet(ex)); } catch (...) { diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index f03165fae5c..1f733d71bdf 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -175,7 +175,7 @@ CUDADeviceContext::~CUDADeviceContext() { Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { - std::lock_guard guard(mutex_); + std::lock_guard guard(mutex_); PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaGetLastError()); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index b1755833791..a9c1984616b 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -98,13 +98,20 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cuda stream in the device context. */ cudaStream_t stream() const; + template + void RecordEvent(cudaEvent_t ev, Callback callback) { + std::lock_guard guard(mutex_); + callback(); + PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); + } + private: CUDAPlace place_; std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - mutable std::mutex mutex_; + mutable std::recursive_mutex mutex_; cudaStream_t stream_; cudnnHandle_t cudnn_handle_; cublasHandle_t cublas_handle_; -- GitLab