diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 2e13b3c8c1cf9c6fd75bb0b34aa5b81858049e41..dc614fc6ba4ac604c1b0ef56a130c27d1d055492
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -124,12 +124,17 @@ struct ScaleLossGradOpHandle : public OpHandle {
   float coeff_;
   Scope *scope_;
   platform::Place place_;
+  cudaEvent_t ev_;
 
   explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
                                  platform::Place place)
       : coeff_(static_cast<float>(1.0 / num_dev)),
         scope_(scope),
-        place_(place) {}
+        place_(place) {
+    PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming));
+  }
+
+  ~ScaleLossGradOpHandle() { PADDLE_ENFORCE(cudaEventDestroy(ev_)); }
 
   void Run() override {
     std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
@@ -141,16 +146,23 @@ struct ScaleLossGradOpHandle : public OpHandle {
     if (platform::is_cpu_place(place_)) {
       *tmp = coeff_;
     } else {
-      memory::Copy(
-          boost::get<platform::CUDAPlace>(place_), tmp, platform::CPUPlace(),
-          &coeff_, sizeof(float),
+      auto stream =
           static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
-              ->stream());
+              ->stream();
+      memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
+                   platform::CPUPlace(), &coeff_, sizeof(float), stream);
+      PADDLE_ENFORCE(cudaEventRecord(ev_, stream));
     }
   }
 
   void Wait(platform::DeviceContext *waited_dev) override {
-    this->dev_ctx_.at(place_)->Wait();
+    if (platform::is_cpu_place(waited_dev->GetPlace())) {
+      this->dev_ctx_.at(place_)->Wait();
+    } else {
+      auto stream =
+          static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
+      PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev_, 0));
+    }
   }
 };
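
For reference, below is a minimal standalone sketch of the synchronization pattern this change adopts: the producer records a timing-disabled CUDA event on its stream right after the async copy, and a consumer stream waits on that event with cudaStreamWaitEvent instead of forcing a host-side wait. It uses only the plain CUDA runtime API; the stream, event, and buffer names are illustrative and none of the Paddle types from the patch appear.

// sketch.cu -- illustrative only, not Paddle code.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CUDA_CHECK(expr)                                        \
  do {                                                          \
    cudaError_t err__ = (expr);                                 \
    if (err__ != cudaSuccess) {                                 \
      std::fprintf(stderr, "%s failed: %s\n", #expr,            \
                   cudaGetErrorString(err__));                  \
      std::exit(1);                                             \
    }                                                           \
  } while (0)

int main() {
  cudaStream_t producer, consumer;
  CUDA_CHECK(cudaStreamCreate(&producer));
  CUDA_CHECK(cudaStreamCreate(&consumer));

  // Same flag as in the patch: the event is used purely for ordering,
  // so timing bookkeeping is disabled.
  cudaEvent_t ev;
  CUDA_CHECK(cudaEventCreateWithFlags(&ev, cudaEventDisableTiming));

  float coeff = 0.5f;
  float *dst = nullptr;
  CUDA_CHECK(cudaMalloc(&dst, sizeof(float)));

  // Producer stream: enqueue the host-to-device copy, then mark its
  // completion by recording the event on the same stream.
  CUDA_CHECK(cudaMemcpyAsync(dst, &coeff, sizeof(float),
                             cudaMemcpyHostToDevice, producer));
  CUDA_CHECK(cudaEventRecord(ev, producer));

  // Consumer stream: everything submitted after this call is ordered
  // behind the event on the GPU timeline; the host thread never blocks here.
  CUDA_CHECK(cudaStreamWaitEvent(consumer, ev, 0));
  // ... kernels that read dst would be launched on `consumer` here ...

  CUDA_CHECK(cudaStreamSynchronize(consumer));
  CUDA_CHECK(cudaFree(dst));
  CUDA_CHECK(cudaEventDestroy(ev));
  CUDA_CHECK(cudaStreamDestroy(producer));
  CUDA_CHECK(cudaStreamDestroy(consumer));
  return 0;
}

Compared with the previous this->dev_ctx_.at(place_)->Wait(), which stalls until the whole device context drains, the event keeps the dependency on the GPU timeline: the waiting stream is ordered behind the copy, but the host thread and unrelated streams keep running.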