提交 b94ffacb 编写于 作者: Yu Yang

SetDev

上级 99f85a9f
...@@ -132,12 +132,12 @@ struct ScaleLossGradOpHandle : public OpHandle { ...@@ -132,12 +132,12 @@ struct ScaleLossGradOpHandle : public OpHandle {
scope_(scope), scope_(scope),
place_(place) { place_(place) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device); cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
// Must set device before create event
PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming));
VLOG(3) << "Create " << ev_;
} }
~ScaleLossGradOpHandle() { ~ScaleLossGradOpHandle() {
VLOG(3) << "Destroy " << ev_; cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaEventDestroy(ev_)); PADDLE_ENFORCE(cudaEventDestroy(ev_));
} }
...@@ -339,13 +339,15 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -339,13 +339,15 @@ struct NCCLAllReduceOpHandle : public OpHandle {
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) { : member_(member) {
for (auto &nccl : member_->communication_streams_) { for (auto &nccl : member_->communication_streams_) {
PADDLE_ENFORCE(cudaEventCreate(&events_[nccl.second.device_id()], int dev_id = nccl.second.device_id();
cudaEventDisableTiming)); cudaSetDevice(dev_id);
PADDLE_ENFORCE(cudaEventCreate(&events_[dev_id], cudaEventDisableTiming));
} }
} }
~NCCLAllReduceOpHandle() { ~NCCLAllReduceOpHandle() {
for (auto &ev : events_) { for (auto &ev : events_) {
cudaSetDevice(ev.first);
PADDLE_ENFORCE(cudaEventDestroy(ev.second)); PADDLE_ENFORCE(cudaEventDestroy(ev.second));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册