From b2c7a9b82850c2e4ffaf7027e82f49fa463defc5 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 16 Mar 2018 16:43:49 +0800 Subject: [PATCH] Wait by stream --- paddle/fluid/framework/parallel_executor.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 5870eac8115..d46adf291b7 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -77,7 +77,7 @@ struct OpHandle { virtual ~OpHandle() {} virtual void Run() { PADDLE_THROW("Not implemented"); } - virtual void Wait() {} + virtual void Wait(platform::DeviceContext *waited_dev) {} }; struct ComputationOpHandle : public OpHandle { @@ -97,13 +97,17 @@ struct ComputationOpHandle : public OpHandle { auto *cur_ctx = dev_ctx_[place_]; for (auto *in : inputs_) { if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) { - in->generated_op_->Wait(); + in->generated_op_->Wait(cur_ctx); } } op_->Run(*scope_, place_); LOG(INFO) << "Done " << this; } + + void Wait(platform::DeviceContext *waited_dev) override { + this->dev_ctx_.at(place_)->Wait(); + } }; struct ScaleLossGradOpHandle : public OpHandle { @@ -136,6 +140,10 @@ struct ScaleLossGradOpHandle : public OpHandle { ->stream()); } } + + void Wait(platform::DeviceContext *waited_dev) override { + this->dev_ctx_.at(place_)->Wait(); + } }; class ParallelExecutorPrivate { @@ -276,6 +284,10 @@ struct NCCLAllReduceOpHandle : public OpHandle { platform::dynload::ncclGroupEnd(); } } + + void Wait(platform::DeviceContext *waited_dev) override { + this->dev_ctx_.at(waited_dev->GetPlace())->Wait(); + } }; ParallelExecutor::ParallelExecutor( -- GitLab