提交 c372ce28 编写于 作者: Y Yu Yang

Add event for computational op

上级 48619bc9
...@@ -92,12 +92,22 @@ struct ComputationOpHandle : public OpHandle { ...@@ -92,12 +92,22 @@ struct ComputationOpHandle : public OpHandle {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
cudaEvent_t event_;
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope, explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place) platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)), : op_(framework::OpRegistry::CreateOp(op_desc)),
scope_(scope), scope_(scope),
place_(place) {} place_(place) {
if (platform::is_gpu_place(place)) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
cudaEventCreateWithFlags(&event_, cudaEventDisableTiming);
}
}
~ComputationOpHandle() {
// FIXME: Destroy Event
}
void Run() override { void Run() override {
// Wait other op if necessary // Wait other op if necessary
...@@ -113,10 +123,22 @@ struct ComputationOpHandle : public OpHandle { ...@@ -113,10 +123,22 @@ struct ComputationOpHandle : public OpHandle {
} }
op_->Run(*scope_, place_); op_->Run(*scope_, place_);
if (platform::is_gpu_place(place_)) {
auto stream = static_cast<platform::CUDADeviceContext *>(dev_ctx_[place_])
->stream();
PADDLE_ENFORCE(cudaEventRecord(event_, stream));
}
} }
void Wait(platform::DeviceContext *waited_dev) override { void Wait(platform::DeviceContext *waited_dev) override {
if (platform::is_cpu_place(waited_dev->GetPlace()) ||
platform::is_cpu_place(place_)) {
this->dev_ctx_.at(place_)->Wait(); this->dev_ctx_.at(place_)->Wait();
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, event_, 0));
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册