提交 8d04d0e2 编写于 作者: Y yi.wu

update

上级 6f010712
...@@ -166,7 +166,7 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -166,7 +166,7 @@ void ParallelExecutor::BCastParamsToGPUs(
void *buffer; void *buffer;
if ((initializing && i == 0) || if ((initializing && i == 0) ||
(!initializing && i == static_cast<size_t>(var_dev_id))) { (!initializing && static_cast<int>(i) == var_dev_id)) {
buffer = const_cast<void *>(main_tensor.data<void>()); buffer = const_cast<void *>(main_tensor.data<void>());
} else { } else {
auto local_scope = member_->local_scopes_[i]; auto local_scope = member_->local_scopes_[i];
...@@ -187,7 +187,7 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -187,7 +187,7 @@ void ParallelExecutor::BCastParamsToGPUs(
platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
nccl_ctx.comm_, nccl_ctx.stream()); nccl_ctx.comm_, nccl_ctx.stream());
} else { } else {
if (static_cast<size_t>(var_dev_id)) { if (var_dev_id >= 0) {
platform::dynload::ncclBcast(buffers[i], numel, data_type, platform::dynload::ncclBcast(buffers[i], numel, data_type,
var_dev_id, nccl_ctx.comm_, var_dev_id, nccl_ctx.comm_,
nccl_ctx.stream()); nccl_ctx.stream());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册