提交 b2c7a9b8 编写于 作者: Y Yu Yang

Wait by stream

上级 e8a7e5d1
......@@ -77,7 +77,7 @@ struct OpHandle {
// Virtual destructor with an empty body.
virtual ~OpHandle() {}
// Default Run(): reports "Not implemented" through PADDLE_THROW.
virtual void Run() { PADDLE_THROW("Not implemented"); }
// Default no-argument Wait(): empty body.
// NOTE(review): this listing appears to show removed and added lines side by
// side without +/- markers; this overload may be the one superseded by the
// DeviceContext overload below -- confirm against the repository.
virtual void Wait() {}
// Default Wait(waited_dev): empty body; waited_dev is not used here.
virtual void Wait(platform::DeviceContext *waited_dev) {}
};
struct ComputationOpHandle : public OpHandle {
......@@ -97,13 +97,17 @@ struct ComputationOpHandle : public OpHandle {
auto *cur_ctx = dev_ctx_[place_];
for (auto *in : inputs_) {
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
in->generated_op_->Wait();
in->generated_op_->Wait(cur_ctx);
}
}
op_->Run(*scope_, place_);
LOG(INFO) << "Done " << this;
}
// Calls Wait() on the device context stored in dev_ctx_ for place_.
// dev_ctx_.at() is used, so a missing entry for place_ is reported by the
// container rather than silently inserted.
// NOTE(review): waited_dev is not consulted by this override -- confirm that
// waiting on this handle's own context is the intended behaviour.
void Wait(platform::DeviceContext *waited_dev) override {
  auto *own_ctx = dev_ctx_.at(place_);
  own_ctx->Wait();
}
};
struct ScaleLossGradOpHandle : public OpHandle {
......@@ -136,6 +140,10 @@ struct ScaleLossGradOpHandle : public OpHandle {
->stream());
}
}
// Looks up the device context registered for place_ and calls Wait() on it.
// The lookup goes through at(), so no entry is created when place_ is absent.
// NOTE(review): the waited_dev argument is unused in this override -- verify
// against the callers that this is intentional.
void Wait(platform::DeviceContext *waited_dev) override {
  auto *ctx_for_place = dev_ctx_.at(place_);
  ctx_for_place->Wait();
}
};
class ParallelExecutorPrivate {
......@@ -276,6 +284,10 @@ struct NCCLAllReduceOpHandle : public OpHandle {
platform::dynload::ncclGroupEnd();
}
}
// Calls Wait() on the device context that dev_ctx_ holds for the place
// reported by waited_dev->GetPlace().
// waited_dev is dereferenced unconditionally, so callers must pass a non-null
// pointer; at() means no entry is created when that place is absent.
void Wait(platform::DeviceContext *waited_dev) override {
  const auto &target_place = waited_dev->GetPlace();
  dev_ctx_.at(target_place)->Wait();
}
};
ParallelExecutor::ParallelExecutor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册