// Explanatory preamble for the patch below (not part of any hunk). The patch
// itself follows on a single physical line and is left byte-for-byte intact.
//
// Target file: paddle/fluid/framework/parallel_executor.cc
//
// What the hunks show:
//   * OpHandle::Wait() becomes Wait(platform::DeviceContext *waited_dev); the
//     base-class body remains empty.
//   * Inside ComputationOpHandle, the loop over inputs_ now passes cur_ctx
//     (read from dev_ctx_[place_]) to the generating op's Wait() whenever that
//     op's dev_ctx_[place_] differs from cur_ctx.
//   * ComputationOpHandle and ScaleLossGradOpHandle each gain a Wait() override
//     that calls dev_ctx_.at(place_)->Wait(); waited_dev is not read there.
//   * NCCLAllReduceOpHandle gains a Wait() override that calls
//     dev_ctx_.at(waited_dev->GetPlace())->Wait().
//
// NOTE(review): the NCCLAllReduceOpHandle override dereferences waited_dev
// without a null check. The only call site visible in this patch passes
// cur_ctx -- confirm that value can never be null.
// NOTE(review): waited_dev is unused in the base class and in two of the three
// overrides; presumably intentional for now -- confirm against the rest of the
// file, which is outside these hunks.
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 5870eac8115a6882bd6bce269377f8bf64849df4..d46adf291b76cad390fb10821da82c547a7a7b37 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -77,7 +77,7 @@ struct OpHandle { virtual ~OpHandle() {} virtual void Run() { PADDLE_THROW("Not implemented"); } - virtual void Wait() {} + virtual void Wait(platform::DeviceContext *waited_dev) {} }; struct ComputationOpHandle : public OpHandle { @@ -97,13 +97,17 @@ struct ComputationOpHandle : public OpHandle { auto *cur_ctx = dev_ctx_[place_]; for (auto *in : inputs_) { if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) { - in->generated_op_->Wait(); + in->generated_op_->Wait(cur_ctx); } } op_->Run(*scope_, place_); LOG(INFO) << "Done " << this; } + + void Wait(platform::DeviceContext *waited_dev) override { + this->dev_ctx_.at(place_)->Wait(); + } }; struct ScaleLossGradOpHandle : public OpHandle { @@ -136,6 +140,10 @@ struct ScaleLossGradOpHandle : public OpHandle { ->stream()); } } + + void Wait(platform::DeviceContext *waited_dev) override { + this->dev_ctx_.at(place_)->Wait(); + } }; class ParallelExecutorPrivate { @@ -276,6 +284,10 @@ struct NCCLAllReduceOpHandle : public OpHandle { platform::dynload::ncclGroupEnd(); } } + + void Wait(platform::DeviceContext *waited_dev) override { + this->dev_ctx_.at(waited_dev->GetPlace())->Wait(); + } }; ParallelExecutor::ParallelExecutor(