未验证 提交 925c17ab 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #9895 from reyoung/feature/fix_transformer_hang

Fix Transformer Hang Problem
...@@ -35,7 +35,9 @@ void ComputationOpHandle::RunImpl() { ...@@ -35,7 +35,9 @@ void ComputationOpHandle::RunImpl() {
} }
} }
this->RunAndRecordEvent([this] {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
});
} }
std::string ComputationOpHandle::Name() const { return op_->Type(); } std::string ComputationOpHandle::Name() const { return op_->Type(); }
......
...@@ -69,10 +69,12 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -69,10 +69,12 @@ void NCCLAllReduceOpHandle::RunImpl() {
}); });
} }
this->RunAndRecordEvent([&] {
platform::NCCLGroupGuard guard; platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) { for (auto &call : all_reduce_calls) {
call(); call();
} }
});
} }
} }
......
...@@ -54,17 +54,6 @@ void OpHandleBase::Run(bool use_event) { ...@@ -54,17 +54,6 @@ void OpHandleBase::Run(bool use_event) {
#endif #endif
RunImpl(); RunImpl();
#ifdef PADDLE_WITH_CUDA
if (use_event) {
for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream =
static_cast<platform::CUDADeviceContext *>(p.second)->stream();
PADDLE_ENFORCE(cudaEventRecord(events_.at(dev_id), stream));
}
}
#endif
} }
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
...@@ -97,6 +86,27 @@ void OpHandleBase::AddOutput(VarHandleBase *out) { ...@@ -97,6 +86,27 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out->generated_op_ = this; out->generated_op_ = this;
} }
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
std::function<void()> method = callback;
for (auto &p : dev_ctxes_) {
method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
};
}
method();
} else {
#endif
callback();
#ifdef PADDLE_WITH_CUDA
}
#endif
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -62,6 +62,8 @@ class OpHandleBase { ...@@ -62,6 +62,8 @@ class OpHandleBase {
virtual bool IsMultiDeviceTransfer() { return false; } virtual bool IsMultiDeviceTransfer() { return false; }
protected: protected:
void RunAndRecordEvent(const std::function<void()> &callback);
virtual void RunImpl() = 0; virtual void RunImpl() = 0;
}; };
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -37,11 +39,13 @@ void ScaleLossGradOpHandle::RunImpl() { ...@@ -37,11 +39,13 @@ void ScaleLossGradOpHandle::RunImpl() {
*tmp = coeff_; *tmp = coeff_;
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
this->RunAndRecordEvent([&] {
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(this->dev_ctxes_[place_]) static_cast<platform::CUDADeviceContext *>(this->dev_ctxes_[place_])
->stream(); ->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp, memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream); platform::CPUPlace(), &coeff_, sizeof(float), stream);
});
#endif #endif
} }
} }
......
...@@ -34,7 +34,7 @@ void SendOpHandle::RunImpl() { ...@@ -34,7 +34,7 @@ void SendOpHandle::RunImpl() {
} }
in->generated_op_->Wait(dev_ctxes_[p]); in->generated_op_->Wait(dev_ctxes_[p]);
} }
op_->Run(*local_scope_, place_); this->RunAndRecordEvent([&] { op_->Run(*local_scope_, place_); });
} }
std::string SendOpHandle::Name() const { return "send"; } std::string SendOpHandle::Name() const { return "send"; }
......
...@@ -196,10 +196,12 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -196,10 +196,12 @@ void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
try { try {
VLOG(10) << op->Name() << " : " << op->DebugString(); VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
op->Run(use_event_); op->Run(use_event_);
VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--; running_ops_--;
ready_var_q->Extend(op->outputs_); ready_var_q->Extend(op->outputs_);
VLOG(10) << op << " " << op->Name() << "Signal posted";
} catch (platform::EnforceNotMet ex) { } catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex)); exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) { } catch (...) {
......
...@@ -175,7 +175,7 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -175,7 +175,7 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; } Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const { void CUDADeviceContext::Wait() const {
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::recursive_mutex> guard(mutex_);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError()); PADDLE_ENFORCE(cudaGetLastError());
} }
......
...@@ -98,13 +98,20 @@ class CUDADeviceContext : public DeviceContext { ...@@ -98,13 +98,20 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cuda stream in the device context. */ /*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const; cudaStream_t stream() const;
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::recursive_mutex> guard(mutex_);
callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}
private: private:
CUDAPlace place_; CUDAPlace place_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
mutable std::mutex mutex_; mutable std::recursive_mutex mutex_;
cudaStream_t stream_; cudaStream_t stream_;
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册