提交 071043c3 编写于 作者: Y Yu Yang

Add paddle enforce

上级 8af57706
...@@ -320,14 +320,14 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -320,14 +320,14 @@ struct NCCLAllReduceOpHandle : public OpHandle {
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) { : member_(member) {
for (auto &nccl : member_->communication_streams_) { for (auto &nccl : member_->communication_streams_) {
cudaEventCreate(&events_[nccl.second.device_id()], PADDLE_ENFORCE(cudaEventCreate(&events_[nccl.second.device_id()],
cudaEventDisableTiming); cudaEventDisableTiming));
} }
} }
~NCCLAllReduceOpHandle() { ~NCCLAllReduceOpHandle() {
for (auto &ev : events_) { for (auto &ev : events_) {
cudaEventDestroy(ev.second); PADDLE_ENFORCE(cudaEventDestroy(ev.second));
} }
} }
...@@ -362,7 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -362,7 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
platform::dynload::ncclAllReduce( platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum, buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream()); nccl_ctx.comm, nccl_ctx.stream());
cudaEventRecord(events_[dev_id], nccl_ctx.stream()); PADDLE_ENFORCE(cudaEventRecord(events_[dev_id], nccl_ctx.stream()));
} }
platform::dynload::ncclGroupEnd(); platform::dynload::ncclGroupEnd();
...@@ -381,7 +381,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -381,7 +381,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
boost::get<platform::CUDAPlace>(waited_dev->GetPlace()).device; boost::get<platform::CUDAPlace>(waited_dev->GetPlace()).device;
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream(); static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
cudaStreamWaitEvent(stream, events_[dev_id], 0); PADDLE_ENFORCE(cudaStreamWaitEvent(stream, events_[dev_id], 0));
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册