Commit d7badb3e authored by: Yu Yang

Use event to sync stream

Parent 3aa7051b
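In outline: after each ncclAllReduce is enqueued, the handle records a CUDA event on that device's communication stream, so a later consumer can make its own stream wait on those events instead of blocking the host. Below is a minimal self-contained sketch of that record-then-wait pattern; the producer/consumer stream names and the main() scaffold are illustrative, not part of this commit.

#include <cuda_runtime.h>

int main() {
  cudaStream_t producer, consumer;
  cudaStreamCreate(&producer);
  cudaStreamCreate(&consumer);

  cudaEvent_t done;
  // Timing is disabled because the event is used purely for ordering,
  // matching the commit's use of cudaEventDisableTiming.
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);

  // ... enqueue work on `producer` here (the commit enqueues ncclAllReduce) ...
  cudaEventRecord(done, producer);

  // Device-side dependency: this call returns immediately, but anything
  // enqueued on `consumer` afterwards runs only once `done` has fired.
  cudaStreamWaitEvent(consumer, done, 0);
  // ... enqueue dependent work on `consumer` here ...

  cudaStreamSynchronize(consumer);
  cudaEventDestroy(done);
  cudaStreamDestroy(producer);
  cudaStreamDestroy(consumer);
  return 0;
}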
@@ -315,9 +315,21 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
 struct NCCLAllReduceOpHandle : public OpHandle {
   ParallelExecutorPrivate *member_;
+  std::vector<cudaEvent_t> events_;

   explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
-      : member_(member) {}
+      : member_(member) {
+    events_.resize(member_->places_.size());
+    for (auto &ev : events_) {
+      cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
+    }
+  }
+
+  ~NCCLAllReduceOpHandle() {
+    for (auto &ev : events_) {
+      cudaEventDestroy(ev);
+    }
+  }

   void Run() override {
     if (this->inputs_.size() == 1) {
@@ -350,6 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
         platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm, nccl_ctx.stream());
+        cudaEventRecord(events_[i], nccl_ctx.stream());
       }
       platform::dynload::ncclGroupEnd();
@@ -357,8 +370,19 @@ struct NCCLAllReduceOpHandle : public OpHandle {
   }

   void Wait(platform::DeviceContext *waited_dev) override {
-    for (auto &pair : member_->communication_streams_) {
-      pair.second.ctx_->Wait();
+    if (platform::is_cpu_place(
+            waited_dev->GetPlace())) {  // Wait by CPU, just sync stream
+      for (auto &pair : member_->communication_streams_) {
+        pair.second.ctx_->Wait();
+      }
+    } else {
+      if (events_.size() > 1) {
+        auto stream =
+            static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
+        for (auto &ev : events_) {
+          cudaStreamWaitEvent(stream, ev, 0);
+        }
+      }
     }
   }
 };
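The Wait() override above distinguishes two kinds of waiters. A CPU waiter has no stream to chain a dependency onto, so it host-synchronizes every communication stream; a GPU waiter merely enqueues cudaStreamWaitEvent calls, which return immediately and keep the dependency on the device. A hedged sketch of that split follows, using a hypothetical free-function signature in place of the handle's members and nullptr as a stand-in sentinel for "CPU waiter":

#include <cuda_runtime.h>
#include <vector>

// Hypothetical helper mirroring the shape of the commit's Wait() logic.
// Pass waiter_stream == nullptr to model a CPU waiter.
void WaitForAllReduce(const std::vector<cudaEvent_t> &events,
                      const std::vector<cudaStream_t> &comm_streams,
                      cudaStream_t waiter_stream) {
  if (waiter_stream == nullptr) {
    // CPU waiter: block the host until every communication stream drains.
    for (cudaStream_t s : comm_streams) {
      cudaStreamSynchronize(s);
    }
  } else {
    // GPU waiter: enqueue device-side dependencies and return immediately.
    for (cudaEvent_t ev : events) {
      cudaStreamWaitEvent(waiter_stream, ev, 0);
    }
  }
}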