提交 65bc7d17 编写于 作者: Y Yu Yang

Add mtx to ncclAllReduce

上级 d42117e7
...@@ -340,6 +340,8 @@ ncclDataType_t ToNCCLDataType(std::type_index type) { ...@@ -340,6 +340,8 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
} }
} }
static std::mutex g_nccl_mtx_;
struct NCCLAllReduceOpHandle : public OpHandle { struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
...@@ -361,6 +363,8 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -361,6 +363,8 @@ struct NCCLAllReduceOpHandle : public OpHandle {
int dtype = -1; int dtype = -1;
size_t numel = 0; size_t numel = 0;
std::lock_guard<std::mutex> g(g_nccl_mtx_);
platform::dynload::ncclGroupStart(); platform::dynload::ncclGroupStart();
for (size_t i = 0; i < member_->local_scopes_.size(); ++i) { for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册