提交 29cc9f30 编写于 作者: Y Yu Yang

SetDev for nccl

上级 d7badb3e
...@@ -358,7 +358,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -358,7 +358,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
} }
auto &nccl_ctx = member_->communication_streams_.at(dev_id); auto &nccl_ctx = member_->communication_streams_.at(dev_id);
cudaSetDevice(dev_id);
platform::dynload::ncclAllReduce( platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum, buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream()); nccl_ctx.comm, nccl_ctx.stream());
...@@ -519,7 +519,6 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -519,7 +519,6 @@ void ParallelExecutor::ConstructDependencyGraph(
var.name_ = og; var.name_ = og;
var.version_ = vars.size() - 1; var.version_ = vars.size() - 1;
op_handle->outputs_.emplace_back(&var); op_handle->outputs_.emplace_back(&var);
op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p); op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册