From c372ce2885684f9d4af26e2e894d70c33e5d4cc8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 19 Mar 2018 20:54:55 +0800 Subject: [PATCH] Add event for computational op --- paddle/fluid/framework/parallel_executor.cc | 26 +++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 3a92494e7e..f841b3b7fa 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -92,12 +92,22 @@ struct ComputationOpHandle : public OpHandle { std::unique_ptr op_; Scope *scope_; platform::Place place_; + cudaEvent_t event_; explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope, platform::Place place) : op_(framework::OpRegistry::CreateOp(op_desc)), scope_(scope), - place_(place) {} + place_(place) { + if (platform::is_gpu_place(place)) { + cudaSetDevice(boost::get(place_).device); + cudaEventCreateWithFlags(&event_, cudaEventDisableTiming); + } + } + + ~ComputationOpHandle() { + // FIXME: Destroy Event + } void Run() override { // Wait other op if necessary @@ -113,10 +123,22 @@ struct ComputationOpHandle : public OpHandle { } op_->Run(*scope_, place_); + if (platform::is_gpu_place(place_)) { + auto stream = static_cast(dev_ctx_[place_]) + ->stream(); + PADDLE_ENFORCE(cudaEventRecord(event_, stream)); + } } void Wait(platform::DeviceContext *waited_dev) override { - this->dev_ctx_.at(place_)->Wait(); + if (platform::is_cpu_place(waited_dev->GetPlace()) || + platform::is_cpu_place(place_)) { + this->dev_ctx_.at(place_)->Wait(); + } else { + auto stream = + static_cast(waited_dev)->stream(); + PADDLE_ENFORCE(cudaStreamWaitEvent(stream, event_, 0)); + } } }; -- GitLab