提交 d7badb3e 编写于 作者: Y Yu Yang

Use event to sync stream

上级 3aa7051b
...@@ -315,9 +315,21 @@ ncclDataType_t ToNCCLDataType(std::type_index type) { ...@@ -315,9 +315,21 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
struct NCCLAllReduceOpHandle : public OpHandle { struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::vector<cudaEvent_t> events_;
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) {} : member_(member) {
events_.resize(member_->places_.size());
for (auto &ev : events_) {
cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
}
}
~NCCLAllReduceOpHandle() {
for (auto &ev : events_) {
cudaEventDestroy(ev);
}
}
void Run() override { void Run() override {
if (this->inputs_.size() == 1) { if (this->inputs_.size() == 1) {
...@@ -350,6 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -350,6 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
platform::dynload::ncclAllReduce( platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum, buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream()); nccl_ctx.comm, nccl_ctx.stream());
cudaEventRecord(events_[i], nccl_ctx.stream());
} }
platform::dynload::ncclGroupEnd(); platform::dynload::ncclGroupEnd();
...@@ -357,8 +370,19 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -357,8 +370,19 @@ struct NCCLAllReduceOpHandle : public OpHandle {
} }
void Wait(platform::DeviceContext *waited_dev) override { void Wait(platform::DeviceContext *waited_dev) override {
for (auto &pair : member_->communication_streams_) { if (platform::is_cpu_place(
pair.second.ctx_->Wait(); waited_dev->GetPlace())) { // Wait by CPU, just sync stream
for (auto &pair : member_->communication_streams_) {
pair.second.ctx_->Wait();
}
} else {
if (events_.size() > 1) {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) {
cudaStreamWaitEvent(stream, ev, 0);
}
}
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册