// Patch summary -- paddle/fluid/framework/parallel_executor.cc, struct ComputationOpHandle.
//
// What the visible hunks change:
//   * A `cudaEvent_t event_` member is added. When `place` is a GPU place, the
//     constructor calls cudaSetDevice(...) for the device of `place_` and then
//     creates `event_` with the cudaEventDisableTiming flag.
//   * Run(): after `op_->Run(*scope_, place_)`, and only on a GPU place, `event_`
//     is recorded on the stream taken from `dev_ctx_[place_]`.
//   * Wait(waited_dev): if either the waiting device's place or this op's
//     `place_` is a CPU place, the previous behaviour is kept
//     (`dev_ctx_.at(place_)->Wait()`); otherwise the waiting device's stream is
//     made to wait on `event_` via cudaStreamWaitEvent(stream, event_, 0).
//
// NOTE(review): the added destructor body contains only a FIXME; no
//   cudaEventDestroy call is visible, so the event created in the constructor
//   is not released by this patch.
// NOTE(review): cudaSetDevice / cudaEventCreateWithFlags results are not wrapped
//   in PADDLE_ENFORCE, unlike cudaEventRecord / cudaStreamWaitEvent below.
// NOTE(review): `event_` has no initializer and is only created on GPU places;
//   the visible uses are guarded by place checks -- confirm that place_ and
//   waited_dev always agree with how the handle was constructed.
// NOTE(review): Run() indexes `dev_ctx_[place_]` while Wait() uses
//   `dev_ctx_.at(place_)`; the container type is not visible here -- confirm the
//   `[]` lookup cannot insert an empty entry.
// NOTE(review): this copy of the patch is collapsed onto one line and template
//   arguments appear to be missing (`std::unique_ptr op_`, `boost::get(place_)`,
//   `static_cast(...)`) -- presumably an extraction artifact; verify against the
//   original commit before applying. The lines of Run() between the two hunks
//   are not part of the patch and are not shown.
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 3a92494e7e918e49355cc60583ada1bf2b24be29..f841b3b7fa84c5a69fe510b5fc239d949c2212b5 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -92,12 +92,22 @@ struct ComputationOpHandle : public OpHandle { std::unique_ptr op_; Scope *scope_; platform::Place place_; + cudaEvent_t event_; explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope, platform::Place place) : op_(framework::OpRegistry::CreateOp(op_desc)), scope_(scope), - place_(place) {} + place_(place) { + if (platform::is_gpu_place(place)) { + cudaSetDevice(boost::get(place_).device); + cudaEventCreateWithFlags(&event_, cudaEventDisableTiming); + } + } + + ~ComputationOpHandle() { + // FIXME: Destroy Event + } void Run() override { // Wait other op if necessary @@ -113,10 +123,22 @@ struct ComputationOpHandle : public OpHandle { } op_->Run(*scope_, place_); + if (platform::is_gpu_place(place_)) { + auto stream = static_cast(dev_ctx_[place_]) + ->stream(); + PADDLE_ENFORCE(cudaEventRecord(event_, stream)); + } } void Wait(platform::DeviceContext *waited_dev) override { - this->dev_ctx_.at(place_)->Wait(); + if (platform::is_cpu_place(waited_dev->GetPlace()) || + platform::is_cpu_place(place_)) { + this->dev_ctx_.at(place_)->Wait(); + } else { + auto stream = + static_cast(waited_dev)->stream(); + PADDLE_ENFORCE(cudaStreamWaitEvent(stream, event_, 0)); + } } };