提交 29cc9f30 编写于 作者: Y Yu Yang

SetDev for nccl

上级 d7badb3e
......@@ -358,7 +358,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
}
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
cudaSetDevice(dev_id);
platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream());
......@@ -519,7 +519,6 @@ void ParallelExecutor::ConstructDependencyGraph(
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->outputs_.emplace_back(&var);
op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册