diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index bc61b0eacbf6c8a1fd4487ad5a442fed1b536345..02e89d18a8dcdd90372e65645ee30c607ae50727 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( ir::Graph *result, const std::string &loss_grad_name) const { for (size_t i = 0; i < places_.size(); ++i) { -// Insert ScaleCost OpHandle -#ifdef PADDLE_WITH_CUDA - auto *communication_dev_ctx = - nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i]) - : platform::DeviceContextPool::Instance().Get(places_[i]); -#else - auto *communication_dev_ctx = - platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); -#endif + // Insert ScaleCost OpHandle + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); auto *op_handle = new ScaleLossGradOpHandle( result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), - local_scopes_.size(), local_scopes_[i], places_[i], - communication_dev_ctx); + local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx); result->Get(kGraphOps).emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale