提交 9824e8f3 编写于 作者: Y Yu Yang

Scale loss op use event

上级 071043c3
...@@ -124,12 +124,17 @@ struct ScaleLossGradOpHandle : public OpHandle { ...@@ -124,12 +124,17 @@ struct ScaleLossGradOpHandle : public OpHandle {
float coeff_; float coeff_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
cudaEvent_t ev_;
explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope, explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
platform::Place place) platform::Place place)
: coeff_(static_cast<float>(1.0 / num_dev)), : coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope), scope_(scope),
place_(place) {} place_(place) {
PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming));
}
~ScaleLossGradOpHandle() { PADDLE_ENFORCE(cudaEventDestroy(ev_)); }
void Run() override { void Run() override {
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_; std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
...@@ -141,16 +146,23 @@ struct ScaleLossGradOpHandle : public OpHandle { ...@@ -141,16 +146,23 @@ struct ScaleLossGradOpHandle : public OpHandle {
if (platform::is_cpu_place(place_)) { if (platform::is_cpu_place(place_)) {
*tmp = coeff_; *tmp = coeff_;
} else { } else {
memory::Copy( auto stream =
boost::get<platform::CUDAPlace>(place_), tmp, platform::CPUPlace(),
&coeff_, sizeof(float),
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_]) static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
->stream()); ->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
PADDLE_ENFORCE(cudaEventRecord(ev_, stream));
} }
} }
void Wait(platform::DeviceContext *waited_dev) override { void Wait(platform::DeviceContext *waited_dev) override {
this->dev_ctx_.at(place_)->Wait(); if (platform::is_cpu_place(waited_dev->GetPlace())) {
this->dev_ctx_.at(place_)->Wait();
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev_, 0));
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册