提交 6b20b355 编写于 作者: Y Yu Yang

Fix Transformer Hang Problem

上级 5a4d9328
......@@ -35,7 +35,9 @@ void ComputationOpHandle::RunImpl() {
}
}
this->RunAndRecordEvent([this] {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
});
}
std::string ComputationOpHandle::Name() const { return op_->Type(); }
......
......@@ -69,10 +69,12 @@ void NCCLAllReduceOpHandle::RunImpl() {
});
}
this->RunAndRecordEvent([&] {
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
});
}
}
......
......@@ -54,17 +54,6 @@ void OpHandleBase::Run(bool use_event) {
#endif
RunImpl();
#ifdef PADDLE_WITH_CUDA
if (use_event) {
for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream =
static_cast<platform::CUDADeviceContext *>(p.second)->stream();
PADDLE_ENFORCE(cudaEventRecord(events_.at(dev_id), stream));
}
}
#endif
}
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
......@@ -97,6 +86,27 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out->generated_op_ = this;
}
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
std::function<void()> method = callback;
for (auto &p : dev_ctxes_) {
method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
};
}
method();
} else {
#endif
callback();
#ifdef PADDLE_WITH_CUDA
}
#endif
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -62,6 +62,8 @@ class OpHandleBase {
virtual bool IsMultiDeviceTransfer() { return false; }
protected:
void RunAndRecordEvent(const std::function<void()> &callback);
virtual void RunImpl() = 0;
};
......
......@@ -14,6 +14,8 @@
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
......@@ -37,11 +39,13 @@ void ScaleLossGradOpHandle::RunImpl() {
*tmp = coeff_;
} else {
#ifdef PADDLE_WITH_CUDA
this->RunAndRecordEvent([&] {
auto stream =
static_cast<platform::CUDADeviceContext *>(this->dev_ctxes_[place_])
->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
});
#endif
}
}
......
......@@ -34,7 +34,7 @@ void SendOpHandle::RunImpl() {
}
in->generated_op_->Wait(dev_ctxes_[p]);
}
op_->Run(*local_scope_, place_);
this->RunAndRecordEvent([&] { op_->Run(*local_scope_, place_); });
}
std::string SendOpHandle::Name() const { return "send"; }
......
......@@ -196,10 +196,12 @@ void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] {
try {
VLOG(10) << op->Name() << " : " << op->DebugString();
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
op->Run(use_event_);
VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--;
ready_var_q->Extend(op->outputs_);
VLOG(10) << op << " " << op->Name() << "Signal posted";
} catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
......
......@@ -175,7 +175,7 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const {
std::lock_guard<std::mutex> guard(mutex_);
std::lock_guard<std::recursive_mutex> guard(mutex_);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError());
}
......
......@@ -98,13 +98,20 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::recursive_mutex> guard(mutex_);
callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}
private:
CUDAPlace place_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
mutable std::mutex mutex_;
mutable std::recursive_mutex mutex_;
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册