提交 2ecef6f2 编写于 作者: M Megvii Engine Team 提交者: liuqingyi

refactor(nccl): disable cudaStreamSync in nccl opr

GitOrigin-RevId: a2604c9d039c05240be86bcede641f8b730e980e
上级 e14e4f84
......@@ -55,7 +55,6 @@ Status NcclCommunicator::send(const void* sendbuff, size_t len, uint32_t rank,
cudaStream_t stream = static_cast<CudaContext*>(ctx.get())->get_stream();
// perform nccl send synchronously
NCCL_CHECK(ncclSend(sendbuff, len, ncclChar, rank, m_comm, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -66,7 +65,6 @@ Status NcclCommunicator::recv(void* recvbuff, size_t len, uint32_t rank,
cudaStream_t stream = static_cast<CudaContext*>(ctx.get())->get_stream();
// perform nccl send synchronously
NCCL_CHECK(ncclRecv(recvbuff, len, ncclChar, rank, m_comm, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -87,8 +85,6 @@ Status NcclCommunicator::scatter(const void* sendbuff, void* recvbuff,
}
NCCL_CHECK(ncclRecv(recvbuff, recvlen, nccl_dtype, root, m_comm, stream));
ncclGroupEnd();
// cuda stream synchronize
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -109,8 +105,6 @@ Status NcclCommunicator::gather(const void* sendbuff, void* recvbuff,
}
NCCL_CHECK(ncclSend(sendbuff, sendlen, nccl_dtype, root, m_comm, stream));
ncclGroupEnd();
// cuda stream synchronize
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -130,8 +124,6 @@ Status NcclCommunicator::all_to_all(const void* sendbuff, void* recvbuff,
NCCL_CHECK(ncclRecv((void*)q, len, nccl_dtype, r, m_comm, stream));
}
ncclGroupEnd();
// cuda stream synchronize
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -143,7 +135,6 @@ Status NcclCommunicator::all_gather(const void* sendbuff, void* recvbuff, size_t
// perform all gather synchronously
NCCL_CHECK(ncclAllGather(sendbuff, recvbuff, sendlen, get_nccl_dtype(dtype),
m_comm, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -155,7 +146,6 @@ Status NcclCommunicator::all_reduce(const void* sendbuff, void* recvbuff, size_t
// perform all reduce synchronously
NCCL_CHECK(ncclAllReduce(sendbuff, recvbuff, len, get_nccl_dtype(dtype),
get_nccl_reduce_op(op), m_comm, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -167,7 +157,6 @@ Status NcclCommunicator::reduce_scatter(const void* sendbuff, void* recvbuff, si
// perform reduce scatter synchronously
NCCL_CHECK(ncclReduceScatter(sendbuff, recvbuff, recvlen, get_nccl_dtype(dtype),
get_nccl_reduce_op(op), m_comm, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -179,7 +168,6 @@ Status NcclCommunicator::broadcast(const void* sendbuff, void* recvbuff, size_t
// perform broadcast synchronously
NCCL_CHECK(ncclBroadcast(sendbuff, recvbuff, len, get_nccl_dtype(dtype), root,
m_comm, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......@@ -191,7 +179,6 @@ Status NcclCommunicator::reduce(const void* sendbuff, void* recvbuff, size_t len
// perform reduce synchronously
NCCL_CHECK(ncclReduce(sendbuff, recvbuff, len, get_nccl_dtype(dtype),
get_nccl_reduce_op(op), root, m_comm, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
return MEGRAY_OK;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册