提交 b94ffacb 编写于 作者: Y Yu Yang

SetDev

上级 99f85a9f
......@@ -132,12 +132,12 @@ struct ScaleLossGradOpHandle : public OpHandle {
scope_(scope),
place_(place) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
// The device must be set before the event is created.
PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming));
VLOG(3) << "Create " << ev_;
}
// Logs the event handle, selects the CUDA device stored in place_, and then
// destroys ev_.
~ScaleLossGradOpHandle() {
  VLOG(3) << "Destroy " << ev_;
  // Mirror the constructor: select the place's device before touching ev_.
  const auto &cuda_place = boost::get<platform::CUDAPlace>(place_);
  cudaSetDevice(cuda_place.device);
  PADDLE_ENFORCE(cudaEventDestroy(ev_));
}
......@@ -339,13 +339,15 @@ struct NCCLAllReduceOpHandle : public OpHandle {
// Creates one timing-disabled CUDA event for every entry of
// member->communication_streams_, stored in events_ under that entry's
// device id.
//
// @param member  executor state whose communication_streams_ are iterated;
//                the pointer is kept in member_.
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
    : member_(member) {
  for (auto &nccl : member_->communication_streams_) {
    const int dev_id = nccl.second.device_id();
    // The device must be selected before the event is created, and the
    // event is created exactly once per device so that no handle stored in
    // events_[dev_id] is overwritten (and thereby leaked).
    PADDLE_ENFORCE(cudaSetDevice(dev_id));
    PADDLE_ENFORCE(cudaEventCreate(&events_[dev_id], cudaEventDisableTiming));
  }
}
// Walks events_ and destroys each stored event, selecting the entry's key
// as the current CUDA device beforehand.
~NCCLAllReduceOpHandle() {
  for (auto it = events_.begin(); it != events_.end(); ++it) {
    cudaSetDevice(it->first);
    PADDLE_ENFORCE(cudaEventDestroy(it->second));
  }
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册