提交 f385228f 编写于 作者: Y Yu Yang

Add Paddle Enforce

上级 833e522d
...@@ -34,7 +34,7 @@ std::string OpHandleBase::DebugString() const { ...@@ -34,7 +34,7 @@ std::string OpHandleBase::DebugString() const {
OpHandleBase::~OpHandleBase() { OpHandleBase::~OpHandleBase() {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
for (auto &ev : events_) { for (auto &ev : events_) {
cudaEventDestroy(ev.second); PADDLE_ENFORCE(cudaEventDestroy(ev.second));
} }
#endif #endif
} }
...@@ -44,8 +44,9 @@ void OpHandleBase::Run(bool use_event) { ...@@ -44,8 +44,9 @@ void OpHandleBase::Run(bool use_event) {
if (events_.empty() && use_event) { if (events_.empty() && use_event) {
for (auto &p : dev_ctx_) { for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device; int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
cudaSetDevice(dev_id); PADDLE_ENFORCE(cudaSetDevice(dev_id));
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming); PADDLE_ENFORCE(
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
} }
} }
#else #else
...@@ -60,7 +61,7 @@ void OpHandleBase::Run(bool use_event) { ...@@ -60,7 +61,7 @@ void OpHandleBase::Run(bool use_event) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device; int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(p.second)->stream(); static_cast<platform::CUDADeviceContext *>(p.second)->stream();
cudaEventRecord(events_.at(dev_id), stream); PADDLE_ENFORCE(cudaEventRecord(events_.at(dev_id), stream));
} }
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册