From a9c8bdad7ba46f1a0a8af3ecf68483b3d8fae6e7 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 14 Oct 2019 10:01:07 +0800 Subject: [PATCH] refine pe codes, test=develop (#20479) --- .../fluid/framework/details/op_handle_base.cc | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index b2fa31f73b9..cbaaa91d1ea 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -188,26 +188,16 @@ bool OpHandleBase::NeedWait(VarHandleBase *in_var) { } void OpHandleBase::RunAndRecordEvent(const std::function &callback) { + callback(); #ifdef PADDLE_WITH_CUDA if (!events_.empty()) { // Use event - std::function method = callback; for (auto &p : dev_ctxes_) { - method = [method, p, this]() { - VLOG(10) << "cudadevicecontext:" - << static_cast(p.second) - << ", dev_id:" - << boost::get(p.first).device; - - static_cast(p.second)->RecordEvent( - events_.at(boost::get(p.first).device), - method); - }; + auto dev_id = boost::get(p.first).device; + auto *cuda_dev_ctx = static_cast(p.second); + VLOG(10) << "cudadevicecontext:" << cuda_dev_ctx << ", dev_id:" << dev_id; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventRecord(events_.at(dev_id), cuda_dev_ctx->stream())); } - method(); - } else { -#endif - callback(); -#ifdef PADDLE_WITH_CUDA } #endif } -- GitLab