提交 8af57706 编写于 作者: Y Yu Yang

Only wait same device

上级 29cc9f30
...@@ -315,19 +315,19 @@ ncclDataType_t ToNCCLDataType(std::type_index type) { ...@@ -315,19 +315,19 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
struct NCCLAllReduceOpHandle : public OpHandle { struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::vector<cudaEvent_t> events_; std::unordered_map<int, cudaEvent_t> events_;
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) { : member_(member) {
events_.resize(member_->places_.size()); for (auto &nccl : member_->communication_streams_) {
for (auto &ev : events_) { cudaEventCreate(&events_[nccl.second.device_id()],
cudaEventCreateWithFlags(&ev, cudaEventDisableTiming); cudaEventDisableTiming);
} }
} }
~NCCLAllReduceOpHandle() { ~NCCLAllReduceOpHandle() {
for (auto &ev : events_) { for (auto &ev : events_) {
cudaEventDestroy(ev); cudaEventDestroy(ev.second);
} }
} }
...@@ -362,7 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -362,7 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
platform::dynload::ncclAllReduce( platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum, buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream()); nccl_ctx.comm, nccl_ctx.stream());
cudaEventRecord(events_[i], nccl_ctx.stream()); cudaEventRecord(events_[dev_id], nccl_ctx.stream());
} }
platform::dynload::ncclGroupEnd(); platform::dynload::ncclGroupEnd();
...@@ -377,11 +377,11 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -377,11 +377,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
} }
} else { } else {
if (events_.size() > 1) { if (events_.size() > 1) {
int dev_id =
boost::get<platform::CUDAPlace>(waited_dev->GetPlace()).device;
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream(); static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) { cudaStreamWaitEvent(stream, events_[dev_id], 0);
cudaStreamWaitEvent(stream, ev, 0);
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册