Commit 8af57706 authored by Yu Yang

Only wait same device

Parent: 29cc9f30
@@ -315,19 +315,19 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
 struct NCCLAllReduceOpHandle : public OpHandle {
   ParallelExecutorPrivate *member_;
-  std::vector<cudaEvent_t> events_;
+  std::unordered_map<int, cudaEvent_t> events_;

   explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
       : member_(member) {
-    events_.resize(member_->places_.size());
-    for (auto &ev : events_) {
-      cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
+    for (auto &nccl : member_->communication_streams_) {
+      cudaEventCreate(&events_[nccl.second.device_id()],
+                      cudaEventDisableTiming);
     }
   }

   ~NCCLAllReduceOpHandle() {
     for (auto &ev : events_) {
-      cudaEventDestroy(ev);
+      cudaEventDestroy(ev.second);
     }
   }
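The hunk above replaces the index-keyed std::vector<cudaEvent_t> with a map keyed by CUDA device id, so each event can later be looked up by the device it belongs to. Below is a minimal standalone sketch of that lifecycle, not the commit's code: the DeviceEvents wrapper and its constructor argument are hypothetical.

    #include <cuda_runtime.h>
    #include <unordered_map>
    #include <vector>

    // Owns one cudaEvent_t per CUDA device, keyed by device id rather than by
    // vector position, mirroring the std::unordered_map<int, cudaEvent_t>
    // introduced in the commit. (Hypothetical wrapper, for illustration only.)
    struct DeviceEvents {
      std::unordered_map<int, cudaEvent_t> events_;

      explicit DeviceEvents(const std::vector<int> &device_ids) {
        for (int dev_id : device_ids) {
          cudaSetDevice(dev_id);  // create each event on the device it tracks
          cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
        }
      }

      ~DeviceEvents() {
        for (auto &ev : events_) {
          cudaEventDestroy(ev.second);  // ev is a (device_id, event) pair
        }
      }
    };

Usage would be e.g. DeviceEvents events({0, 1}); after which events.events_[1] names the event belonging to device 1, independent of insertion order.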
@@ -362,7 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
       platform::dynload::ncclAllReduce(
           buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
           nccl_ctx.comm, nccl_ctx.stream());
-      cudaEventRecord(events_[i], nccl_ctx.stream());
+      cudaEventRecord(events_[dev_id], nccl_ctx.stream());
     }
     platform::dynload::ncclGroupEnd();
@@ -377,11 +377,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
       }
     } else {
       if (events_.size() > 1) {
+        int dev_id =
+            boost::get<platform::CUDAPlace>(waited_dev->GetPlace()).device;
         auto stream =
             static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
-        for (auto &ev : events_) {
-          cudaStreamWaitEvent(stream, ev, 0);
-        }
+        cudaStreamWaitEvent(stream, events_[dev_id], 0);
       }
     }
   }
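Taken together, the three hunks implement a record-then-wait pattern: after each ncclAllReduce is issued, an event is recorded on that device's NCCL stream, and a stream that later depends on the result waits only on the event of its own device rather than on every device's event (hence the commit title). A hedged standalone sketch of the same pattern with plain CUDA streams follows; the helper names RecordDone and WaitSameDevice are hypothetical, not from the commit.

    #include <cuda_runtime.h>
    #include <unordered_map>

    // Producer side: after issuing work (here, the all-reduce) on dev_id's
    // stream, record that device's event on the same stream.
    void RecordDone(std::unordered_map<int, cudaEvent_t> &events, int dev_id,
                    cudaStream_t producer_stream) {
      cudaEventRecord(events[dev_id], producer_stream);
    }

    // Consumer side: make a stream running on device dev_id wait only for the
    // event recorded on that same device, instead of looping over all events.
    void WaitSameDevice(std::unordered_map<int, cudaEvent_t> &events, int dev_id,
                        cudaStream_t consumer_stream) {
      cudaStreamWaitEvent(consumer_stream, events[dev_id], 0);
    }

The design point of the change: cudaStreamWaitEvent only needs to order the consumer stream against work on its own device, so waiting on the other devices' events adds cross-device synchronization that the consumer never required.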