diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 6777aec488d72d2d75872098eb69a42972d94c57..f7dc833937162f5892429e4817ce75f0b458b9f0 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -358,7 +358,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { } auto &nccl_ctx = member_->communication_streams_.at(dev_id); - + cudaSetDevice(dev_id); platform::dynload::ncclAllReduce( buffer, buffer, numel, static_cast(dtype), ncclSum, nccl_ctx.comm, nccl_ctx.stream()); @@ -519,7 +519,6 @@ void ParallelExecutor::ConstructDependencyGraph( var.name_ = og; var.version_ = vars.size() - 1; op_handle->outputs_.emplace_back(&var); - op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p); } }