提交 b2c7a9b8 编写于 作者: Y Yu Yang

Wait by stream

上级 e8a7e5d1
...@@ -77,7 +77,7 @@ struct OpHandle { ...@@ -77,7 +77,7 @@ struct OpHandle {
virtual ~OpHandle() {} virtual ~OpHandle() {}
virtual void Run() { PADDLE_THROW("Not implemented"); } virtual void Run() { PADDLE_THROW("Not implemented"); }
virtual void Wait() {} virtual void Wait(platform::DeviceContext *waited_dev) {}
}; };
struct ComputationOpHandle : public OpHandle { struct ComputationOpHandle : public OpHandle {
...@@ -97,13 +97,17 @@ struct ComputationOpHandle : public OpHandle { ...@@ -97,13 +97,17 @@ struct ComputationOpHandle : public OpHandle {
auto *cur_ctx = dev_ctx_[place_]; auto *cur_ctx = dev_ctx_[place_];
for (auto *in : inputs_) { for (auto *in : inputs_) {
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) { if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
in->generated_op_->Wait(); in->generated_op_->Wait(cur_ctx);
} }
} }
op_->Run(*scope_, place_); op_->Run(*scope_, place_);
LOG(INFO) << "Done " << this; LOG(INFO) << "Done " << this;
} }
void Wait(platform::DeviceContext *waited_dev) override {
this->dev_ctx_.at(place_)->Wait();
}
}; };
struct ScaleLossGradOpHandle : public OpHandle { struct ScaleLossGradOpHandle : public OpHandle {
...@@ -136,6 +140,10 @@ struct ScaleLossGradOpHandle : public OpHandle { ...@@ -136,6 +140,10 @@ struct ScaleLossGradOpHandle : public OpHandle {
->stream()); ->stream());
} }
} }
void Wait(platform::DeviceContext *waited_dev) override {
this->dev_ctx_.at(place_)->Wait();
}
}; };
class ParallelExecutorPrivate { class ParallelExecutorPrivate {
...@@ -276,6 +284,10 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -276,6 +284,10 @@ struct NCCLAllReduceOpHandle : public OpHandle {
platform::dynload::ncclGroupEnd(); platform::dynload::ncclGroupEnd();
} }
} }
void Wait(platform::DeviceContext *waited_dev) override {
this->dev_ctx_.at(waited_dev->GetPlace())->Wait();
}
}; };
ParallelExecutor::ParallelExecutor( ParallelExecutor::ParallelExecutor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册